提交 669816e2 编写于 作者: M Megvii Engine Team

feat(dnn): warpperspective support multi src input

GitOrigin-RevId: 8a4789852e6df47b5b44ac3e2fa2999d4e0ab5d6
上级 33b27be8
...@@ -16,10 +16,18 @@ protected: ...@@ -16,10 +16,18 @@ protected:
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
check_layout_fwd(src, mat, {}, dst); check_layout_fwd(src, mat, {}, dst);
} }
void check_layout_fwd(
const TensorLayoutArray& srcs, const TensorLayout& mat,
const TensorLayout& dst) {
check_layout_fwd(srcs, mat, {}, dst);
}
void check_layout_fwd( void check_layout_fwd(
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst); const TensorLayout& mat_idx, const TensorLayout& dst);
void check_layout_fwd(
const TensorLayoutArray& srcs, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst);
std::string param_msg() const; std::string param_msg() const;
int get_real_coord(int p, int len); int get_real_coord(int p, int len);
}; };
...@@ -49,6 +57,12 @@ public: ...@@ -49,6 +57,12 @@ public:
exec(src, mat, {}, dst, workspace); exec(src, mat, {}, dst, workspace);
} }
void exec(
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat,
_megdnn_tensor_out dst, _megdnn_workspace workspace) {
exec(srcs, mat, {}, dst, workspace);
}
/** /**
* \p src should have batch size m, and \p mat and \p mat_idx should * \p src should have batch size m, and \p mat and \p mat_idx should
* both have batch size n. Each item in \p mat_idx must be in the range * both have batch size n. Each item in \p mat_idx must be in the range
...@@ -62,15 +76,30 @@ public: ...@@ -62,15 +76,30 @@ public:
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat,
_megdnn_tensor_in mat_idx, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
size_t get_workspace_in_bytes( size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
return get_workspace_in_bytes(src, mat, {}, dst); return get_workspace_in_bytes(src, mat, {}, dst);
} }
size_t get_workspace_in_bytes(
const TensorLayoutArray& srcs, const TensorLayout& mat,
const TensorLayout& dst) {
return get_workspace_in_bytes(srcs, mat, {}, dst);
}
virtual size_t get_workspace_in_bytes( virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst) = 0; const TensorLayout& mat_idx, const TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayoutArray& srcs, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst) = 0;
protected: protected:
void check_exec( void check_exec(
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& src, const TensorLayout& mat,
...@@ -81,6 +110,10 @@ protected: ...@@ -81,6 +110,10 @@ protected:
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst, const TensorLayout& mat_idx, const TensorLayout& dst,
size_t workspace_in_bytes); size_t workspace_in_bytes);
void check_exec_allow_nhwc_mat_idx(
const TensorLayoutArray& srcs, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using WarpPerspective = WarpPerspectiveForward; using WarpPerspective = WarpPerspectiveForward;
......
...@@ -22,4 +22,11 @@ bool warp::is_dnn_available( ...@@ -22,4 +22,11 @@ bool warp::is_dnn_available(
return imode == param::WarpAffine::InterpolationMode::LINEAR; return imode == param::WarpAffine::InterpolationMode::LINEAR;
} }
bool warp::is_dnn_available(
const TensorLayoutArray& /*src*/, const TensorLayout& /*mat*/,
const TensorLayout& /*dst*/, param::WarpAffine::InterpolationMode imode,
param::WarpAffine::Format /*format*/) {
return imode == param::WarpAffine::InterpolationMode::LINEAR;
}
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -90,6 +90,10 @@ bool is_dnn_available( ...@@ -90,6 +90,10 @@ bool is_dnn_available(
const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&,
param::WarpAffine::InterpolationMode imode, param::WarpAffine::Format format); param::WarpAffine::InterpolationMode imode, param::WarpAffine::Format format);
bool is_dnn_available(
const TensorLayoutArray&, const TensorLayout&, const TensorLayout&,
param::WarpAffine::InterpolationMode imode, param::WarpAffine::Format format);
using namespace megcv; using namespace megcv;
using IMode = InterpolationMode; using IMode = InterpolationMode;
using BMode = BorderMode; using BMode = BorderMode;
......
...@@ -3,7 +3,97 @@ ...@@ -3,7 +3,97 @@
#include "src/common/utils.h" #include "src/common/utils.h"
namespace megdnn { namespace megdnn {
void WarpPerspectiveBase::check_layout_fwd(
const TensorLayoutArray& srcs, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst) {
megdnn_assert(srcs.size() > 0);
auto s = srcs.front();
for (auto&& src : srcs) {
megdnn_assert_contiguous(src);
megdnn_assert(src.dtype == s.dtype);
megdnn_assert(src.ndim == s.ndim);
megdnn_assert(src.shape[0] == 1);
for (size_t i = 0; i < s.ndim; i++) {
megdnn_assert(src.shape[i] == s.shape[i]);
}
megdnn_assert(src.format == s.format);
}
megdnn_assert_contiguous(mat);
megdnn_assert_contiguous(dst);
auto errmsg = [&]() {
std::string msg = "{";
for (auto&& src : srcs) {
msg.append(megdnn_layout_msg(src) + ", ");
}
return msg + "} " + megdnn_layout_msg(mat) + ", " + megdnn_layout_msg(mat_idx) +
", " + megdnn_layout_msg(dst) + ", " + param_msg();
};
MEGDNN_MARK_USED_VAR(errmsg);
megdnn_assert(
param().format == param::WarpPerspective::Format::NHWC ||
param().format == param::WarpPerspective::Format::NCHW);
megdnn_assert(s.ndim == 4_z, "%s", errmsg().c_str());
megdnn_assert(dst.ndim == 4_z, "%s", errmsg().c_str());
megdnn_assert(mat.ndim == 3_z, "%s", errmsg().c_str());
megdnn_assert(dst.shape[0] == mat.shape[0], "%s", errmsg().c_str());
if (mat_idx.ndim) {
megdnn_assert(
mat_idx.dtype == dtype::Int32() && mat_idx.ndim == 1, "%s",
errmsg().c_str());
megdnn_assert(mat.shape[0] == mat_idx.shape[0], "%s", errmsg().c_str());
megdnn_assert_contiguous(mat_idx);
} else {
megdnn_assert(s.shape[0] * srcs.size() == dst.shape[0], "%s", errmsg().c_str());
}
megdnn_assert(mat.shape[1] == 3_z, "%s", errmsg().c_str());
megdnn_assert(mat.shape[2] == 3_z, "%s", errmsg().c_str());
if (s.format == dst.format && dst.dtype == s.dtype) {
if (param().format == param::WarpPerspective::Format::NCHW) {
megdnn_assert(
s.dtype.enumv() == DTypeEnum::Float32 ||
DNN_FLOAT16_SELECT(
(s.dtype.enumv() == DTypeEnum::Float16 ||
s.dtype.enumv() == DTypeEnum::BFloat16),
false),
"WarpPerspective multi src NCHW input dtype should be "
"Float32" DNN_FLOAT16_SELECT("/Float16/BFloat16", "") ".");
megdnn_assert(
(s.dtype.category() == DTypeCategory::FLOAT &&
(s.dtype == mat.dtype || mat.dtype.enumv() == DTypeEnum::Float32)),
"The input to WarpPerspective multi src is in NCHW format, in this "
"case, if the input dtype is floating point, the "
"transformation matrix should have same dtype as the "
"input, otherwise, it should be in Float32, %s given.",
mat.dtype.name());
megdnn_assert(s.shape[1] == dst.shape[1], "%s", errmsg().c_str());
megdnn_assert(
param().imode == param::WarpPerspective::InterpolationMode::LINEAR);
megdnn_assert(
param().bmode != param::WarpPerspective::BorderMode::TRANSPARENT);
megdnn_assert(
param().bmode != param::WarpPerspective::BorderMode::ISOLATED);
} else {
megdnn_assert(param().format == param::WarpPerspective::Format::NHWC);
megdnn_assert(
s.dtype.enumv() == DTypeEnum::Float32 ||
DNN_FLOAT16_SELECT(
(s.dtype.enumv() == DTypeEnum::Float16 ||
s.dtype.enumv() == DTypeEnum::BFloat16),
false),
"WarpPerspective multi src NHWC input dtype should be "
"Float32" DNN_FLOAT16_SELECT("/Float16/BFloat16", "") ".");
megdnn_assert(s.shape[3] == dst.shape[3], "%s", errmsg().c_str());
}
} else {
megdnn_assert(
0,
"WarpPerspective multi src only support format NHWC/NCHW, dtype "
"Float32" DNN_FLOAT16_SELECT("/Float16/BFloat16", "") ".");
}
}
void WarpPerspectiveBase::check_layout_fwd( void WarpPerspectiveBase::check_layout_fwd(
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& mat_idx, const TensorLayout& src, const TensorLayout& mat, const TensorLayout& mat_idx,
const TensorLayout& dst) { const TensorLayout& dst) {
...@@ -295,6 +385,19 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx( ...@@ -295,6 +385,19 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
} }
} }
void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
const TensorLayoutArray& srcs, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst,
size_t workspace_in_bytes) {
check_layout_fwd(srcs, mat, mat_idx, dst);
auto required_workspace_in_bytes = get_workspace_in_bytes(srcs, mat, mat_idx, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
if (param().format != Param::Format::NHWC &&
param().format != Param::Format::NCHW) {
megdnn_assert(!mat_idx.ndim, "mat_idx not supported for current format");
}
}
void WarpPerspectiveBackwardData::check_exec( void WarpPerspectiveBackwardData::check_exec(
const TensorLayout& mat, const TensorLayout& mat_idx, const TensorLayout& diff, const TensorLayout& mat, const TensorLayout& mat_idx, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_in_bytes) { const TensorLayout& grad, size_t workspace_in_bytes) {
......
...@@ -17,6 +17,13 @@ void forward_proxy( ...@@ -17,6 +17,13 @@ void forward_proxy(
ctype bval, BorderMode bmode, megcore::AsyncErrorInfo* error_info, ctype bval, BorderMode bmode, megcore::AsyncErrorInfo* error_info,
void* error_tracker, cudaStream_t stream); void* error_tracker, cudaStream_t stream);
template <typename ctype>
void forward_proxy_multi_src(
bool is_nhwc, const ctype** srcs, const float* mat, const int* mat_idx,
ctype* dst, int N_SRC, int N_MAT, int C, int IH, int IW, int OH, int OW,
ctype bval, BorderMode bmode, megcore::AsyncErrorInfo* error_info,
void* error_tracker, cudaStream_t stream);
template <typename ctype, int pack_c> template <typename ctype, int pack_c>
void forward_proxy_nhwc_bit4( void forward_proxy_nhwc_bit4(
const ctype* src, const float* mat, const int* mat_idx, ctype* dst, int N_SRC, const ctype* src, const float* mat, const int* mat_idx, ctype* dst, int N_SRC,
......
...@@ -143,6 +143,34 @@ WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle( ...@@ -143,6 +143,34 @@ WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle(
return {ptr, std::move(sizes)}; return {ptr, std::move(sizes)};
} }
WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle(
void* ptr, const TensorLayoutArray& srcs, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst) const {
MEGDNN_MARK_USED_VAR(mat_idx);
SmallVector<size_t> sizes;
TensorLayoutArray fsrcs = srcs;
TensorLayout fmat = mat;
TensorLayout fdst = dst;
auto get_workspace = [&sizes](TensorLayout& layout) {
if (layout.dtype == dtype::BFloat16()) {
layout.dtype = dtype::Float32();
sizes.push_back(layout.span().dist_byte());
}
};
for (auto&& fsrc : fsrcs) {
get_workspace(fsrc);
}
get_workspace(fmat);
get_workspace(fdst);
sizes.push_back(sizeof(dt_float32*) * srcs.size());
if (param().format == param::WarpPerspective::Format::NHWC) {
//! use double for the workspace dtype as float may cause
//! accuracy problems
sizes.push_back(mat.total_nr_elems() * sizeof(double));
}
return {ptr, std::move(sizes)};
}
void WarpPerspectiveForwardImpl::exec( void WarpPerspectiveForwardImpl::exec(
_megdnn_tensor_in ssrc, _megdnn_tensor_in smat, _megdnn_tensor_in smat_idx, _megdnn_tensor_in ssrc, _megdnn_tensor_in smat, _megdnn_tensor_in smat_idx,
_megdnn_tensor_out sdst, _megdnn_workspace sworkspace) { _megdnn_tensor_out sdst, _megdnn_workspace sworkspace) {
...@@ -453,6 +481,124 @@ void WarpPerspectiveForwardImpl::exec( ...@@ -453,6 +481,124 @@ void WarpPerspectiveForwardImpl::exec(
} }
} }
void WarpPerspectiveForwardImpl::exec(
_megdnn_in const TensorNDArray& ssrcs, _megdnn_tensor_in smat,
_megdnn_tensor_in smat_idx, _megdnn_tensor_out sdst,
_megdnn_workspace sworkspace) {
TensorLayoutArray ssrcs_layout;
for (auto&& s : ssrcs) {
ssrcs_layout.push_back(s.layout);
}
check_exec_allow_nhwc_mat_idx(
ssrcs_layout, smat.layout, smat_idx.layout, sdst.layout, sworkspace.size);
TensorNDArray srcs = ssrcs;
TensorND mat = smat;
TensorND mat_idx = smat_idx;
TensorND dst = sdst;
Param::Format inner_format = param().format;
auto bundle = get_workspace_bundle(
sworkspace.raw_ptr, ssrcs_layout, smat.layout, smat_idx.layout,
sdst.layout);
auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
concrete_handle(this->handle()), &bundle);
if (ssrcs.front().layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
for (size_t i = 0; i < ssrcs.size(); i++) {
ctypecvt.src_to_comp_type(ssrcs[i], srcs[i]);
}
ctypecvt.src_to_comp_type(smat, mat).src_to_comp_type(sdst, dst);
}
{
auto stream = cuda_stream(this->handle());
bool is_nhwc = inner_format == param::WarpPerspective::Format::NHWC;
TensorND src = srcs.front();
megdnn_assert(warp::is_dnn_available(
ssrcs_layout, mat.layout, dst.layout, param().imode, inner_format));
size_t C, IH, IW, OH, OW;
if (is_nhwc) {
C = src.layout.shape[3];
IH = src.layout.shape[1];
IW = src.layout.shape[2];
OH = dst.layout.shape[1];
OW = dst.layout.shape[2];
} else {
megdnn_assert(
inner_format == param::WarpPerspective::Format::NCHW,
"invalid warp_perspective format");
C = src.layout.shape[1];
IH = src.layout.shape[2];
IW = src.layout.shape[3];
OH = dst.layout.shape[2];
OW = dst.layout.shape[3];
}
megdnn_assert(
param().imode == Param::InterpolationMode::LINEAR,
"unsupported interpolation mode form NCHW format");
auto bval = param().border_val;
auto bmode = warp_perspective::get_bmode(param().bmode);
if (src.layout.dtype == dst.layout.dtype) {
if (src.layout.dtype == dtype::Float32{}) {
SmallVector<size_t> workspace_sizes{sizeof(dt_float32*) * srcs.size()};
WorkspaceBundle workspace_cpu(nullptr, workspace_sizes);
auto total_workspace_size = workspace_cpu.total_size_in_bytes();
void* workspace_cpu_raw = malloc(total_workspace_size);
workspace_cpu = WorkspaceBundle(workspace_cpu_raw, workspace_sizes);
auto srcs_cpu = static_cast<const dt_float32**>(workspace_cpu.get(0));
size_t i =
is_nhwc ? bundle.nr_workspace() - 2 : bundle.nr_workspace() - 1;
auto srcs_gpu = static_cast<const dt_float32**>(bundle.get(i));
for (size_t i = 0; i < srcs.size(); ++i) {
srcs_cpu[i] = srcs[i].ptr<dt_float32>();
}
cuda_check(cudaMemcpyAsync(
bundle.get(i), workspace_cpu.get(0), workspace_cpu.get_size(0),
cudaMemcpyHostToDevice, stream));
cuda_check(cudaStreamAddCallback(
stream, callback_free, static_cast<void*>(workspace_cpu_raw),
0));
warp_perspective::forward_proxy_multi_src(
is_nhwc, srcs_gpu, mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_float32>(), srcs.size(), mat.layout[0], C, IH, IW,
OH, OW, bval, bmode, async_error_info(handle()),
m_error_tracker, stream);
} else if (DNN_FLOAT16_SELECT(
src.layout.dtype == dtype::Float16(), false)) {
#ifndef MEGDNN_DISABLE_FLOAT16
SmallVector<size_t> workspace_sizes{sizeof(dt_float16*) * srcs.size()};
WorkspaceBundle workspace_cpu(nullptr, workspace_sizes);
auto total_workspace_size = workspace_cpu.total_size_in_bytes();
void* workspace_cpu_raw = malloc(total_workspace_size);
workspace_cpu = WorkspaceBundle(workspace_cpu_raw, workspace_sizes);
auto srcs_cpu = static_cast<const dt_float16**>(workspace_cpu.get(0));
auto srcs_gpu = static_cast<const dt_float16**>(bundle.get(0));
for (size_t i = 0; i < srcs.size(); ++i) {
srcs_cpu[i] = srcs[i].ptr<dt_float16>();
}
cuda_check(cudaMemcpyAsync(
bundle.get(0), workspace_cpu.get(0), workspace_cpu.get_size(0),
cudaMemcpyHostToDevice, stream));
cuda_check(cudaStreamAddCallback(
stream, callback_free, static_cast<void*>(workspace_cpu_raw),
0));
warp_perspective::forward_proxy_multi_src(
is_nhwc, srcs_gpu, mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_float16>(), srcs.size(), mat.layout[0], C, IH, IW,
OH, OW, static_cast<dt_float16>(bval), bmode,
async_error_info(handle()), m_error_tracker, stream);
#endif
}
} else {
megdnn_throw(ssprintf("unsupported dtype: %s", src.layout.dtype.name()));
}
}
if (ssrcs.front().layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
ctypecvt.comp_to_dst_type(dst, sdst);
}
}
} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn
......
...@@ -47,11 +47,16 @@ struct CtypeHelper<dt_quint4> { ...@@ -47,11 +47,16 @@ struct CtypeHelper<dt_quint4> {
template <typename ctype> template <typename ctype>
struct DirectSrcVisitor { struct DirectSrcVisitor {
const void* ptr; const void* ptr;
const void** ptrs;
__device__ __forceinline__ const ctype* get(int batch, int im_size) { __device__ __forceinline__ const ctype* get(int batch, int im_size) {
return (ctype*)((char*)ptr + static_cast<int64_t>(batch) * static_cast<int64_t>(im_size) * CtypeHelper<ctype>::bit_width / 8); return (ctype*)((char*)ptr + static_cast<int64_t>(batch) * static_cast<int64_t>(im_size) * CtypeHelper<ctype>::bit_width / 8);
} }
__device__ __forceinline__ const ctype* get(int batch) {
return (ctype*)(ptrs[batch]);
}
void move_batch(size_t batch, size_t im_size) { void move_batch(size_t batch, size_t im_size) {
ptr = (char*)ptr + batch * im_size * CtypeHelper<ctype>::bit_width / 8; ptr = (char*)ptr + batch * im_size * CtypeHelper<ctype>::bit_width / 8;
} }
...@@ -60,6 +65,7 @@ struct DirectSrcVisitor { ...@@ -60,6 +65,7 @@ struct DirectSrcVisitor {
template <typename ctype> template <typename ctype>
struct IndexedSrcVisitor { struct IndexedSrcVisitor {
const void* ptr; const void* ptr;
const void** ptrs;
const int* idx; const int* idx;
int N_SRC; int N_SRC;
...@@ -79,9 +85,58 @@ struct IndexedSrcVisitor { ...@@ -79,9 +85,58 @@ struct IndexedSrcVisitor {
return (ctype*)((char*)ptr + static_cast<int64_t>(batch) * static_cast<int64_t>(im_size) * CtypeHelper<ctype>::bit_width / 8); return (ctype*)((char*)ptr + static_cast<int64_t>(batch) * static_cast<int64_t>(im_size) * CtypeHelper<ctype>::bit_width / 8);
} }
__device__ __forceinline__ const ctype* get(int batch) {
int orig_batch = batch;
batch = idx[batch];
if (batch < 0 || batch >= N_SRC) {
set_async_error_info(
error_info, error_tracker,
"mat_idx out of bound: mat_idx[%d]=%d src_batch=%d", orig_batch,
batch, N_SRC);
batch = 0;
}
return (ctype*)(ptrs[batch]);
}
void move_batch(size_t batch, size_t) { idx += batch; } void move_batch(size_t batch, size_t) { idx += batch; }
}; };
template <
typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_multi_src(
SrcVisitor srcs, const float* __restrict mat, ctype* __restrict dst, int C,
int IH, int IW, int OH, int OW) {
Getter getter;
OutputConverter output_converter;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
const ctype* __restrict sptr = srcs.get(blockIdx.z);
dst += blockIdx.z * C * OH * OW;
mat += blockIdx.z * 3 * 3;
if (ow < OW && oh < OH) {
float denominator = mat[6] * ow + mat[7] * oh + mat[8];
float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator;
float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator;
int iw0 = getter(floor(iw) + 0, IW);
int iw1 = getter(floor(iw) + 1, IW);
int ih0 = getter(floor(ih) + 0, IH);
int ih1 = getter(floor(ih) + 1, IH);
float palpha = ih - floor(ih);
float pbeta = iw - floor(iw);
float nalpha = 1.0f - palpha;
float nbeta = 1.0f - pbeta;
for (int c = 0; c < C; ++c) {
dst[oh * OW + ow] = output_converter(
sptr[ih0 * IW + iw0] * nalpha * nbeta +
sptr[ih0 * IW + iw1] * nalpha * pbeta +
sptr[ih1 * IW + iw0] * palpha * nbeta +
sptr[ih1 * IW + iw1] * palpha * pbeta);
sptr += IH * IW;
dst += OH * OW;
}
}
}
template < template <
typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter> typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general( __global__ void kern_general(
...@@ -261,6 +316,47 @@ __global__ void kern_general_nchw64( ...@@ -261,6 +316,47 @@ __global__ void kern_general_nchw64(
} }
} }
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_const_border_multi_src(
SrcVisitor srcs, const float* __restrict mat, ctype* __restrict dst, int C,
int IH, int IW, int OH, int OW, ctype bval) {
OutputConverter output_converter;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
const ctype* __restrict sptr = srcs.get(blockIdx.z);
dst += blockIdx.z * C * OH * OW;
mat += blockIdx.z * 3 * 3;
if (ow < OW && oh < OH) {
float denominator = mat[6] * ow + mat[7] * oh + mat[8];
float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator;
float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator;
int iw0 = floor(iw) + 0;
int iw1 = floor(iw) + 1;
int ih0 = floor(ih) + 0;
int ih1 = floor(ih) + 1;
bool okw0 = (iw0 >= 0 && iw0 < IW);
bool okw1 = (iw1 >= 0 && iw1 < IW);
bool okh0 = (ih0 >= 0 && ih0 < IH);
bool okh1 = (ih1 >= 0 && ih1 < IH);
float palpha = ih - floor(ih);
float pbeta = iw - floor(iw);
float nalpha = 1.0f - palpha;
float nbeta = 1.0f - pbeta;
for (int c = 0; c < C; ++c) {
ctype v00 = (okh0 && okw0 ? sptr[ih0 * IW + iw0] : bval);
ctype v01 = (okh0 && okw1 ? sptr[ih0 * IW + iw1] : bval);
ctype v10 = (okh1 && okw0 ? sptr[ih1 * IW + iw0] : bval);
ctype v11 = (okh1 && okw1 ? sptr[ih1 * IW + iw1] : bval);
ctype val = output_converter(
v00 * nalpha * nbeta + v01 * nalpha * pbeta + v10 * palpha * nbeta +
v11 * palpha * pbeta);
dst[oh * OW + ow] = val;
sptr += IH * IW;
dst += OH * OW;
}
}
}
template <typename ctype, typename SrcVisitor, typename OutputConverter> template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_const_border( __global__ void kern_const_border(
SrcVisitor src, const float* __restrict mat, ctype* __restrict dst, int C, SrcVisitor src, const float* __restrict mat, ctype* __restrict dst, int C,
...@@ -553,6 +649,51 @@ struct KernCoreNHWC<ctype, OutputConverter, 16> { ...@@ -553,6 +649,51 @@ struct KernCoreNHWC<ctype, OutputConverter, 16> {
} }
}; };
template <
typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter,
int pack_c>
__global__ void kern_general_nhwc_multi_src(
SrcVisitor srcs, const float* __restrict mat, ctype* __restrict dst, int C,
int IH, int IW, int OH, int OW) {
Getter getter;
OutputConverter output_converter;
constexpr int bit_width = CtypeHelper<ctype>::bit_width;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
const ctype* __restrict sptr = srcs.get(blockIdx.z);
dst = (ctype*)((char*)dst + blockIdx.z * C * OH * OW * bit_width / 8);
mat += blockIdx.z * 3 * 3;
if (ow < OW && oh < OH) {
float denominator = mat[6] * ow + mat[7] * oh + mat[8];
float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator;
float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator;
int iw0 = getter(floor(iw) + 0, IW);
int iw1 = getter(floor(iw) + 1, IW);
int ih0 = getter(floor(ih) + 0, IH);
int ih1 = getter(floor(ih) + 1, IH);
float palpha = ih - floor(ih);
float pbeta = iw - floor(iw);
float nalpha = 1.0f - palpha;
float nbeta = 1.0f - pbeta;
float w00 = nalpha * nbeta;
float w01 = nalpha * pbeta;
float w10 = palpha * nbeta;
float w11 = palpha * pbeta;
const char* src_ptr0 = (char*)sptr + (ih0 * IW + iw0) * C * bit_width / 8;
const char* src_ptr1 = (char*)sptr + (ih0 * IW + iw1) * C * bit_width / 8;
const char* src_ptr2 = (char*)sptr + (ih1 * IW + iw0) * C * bit_width / 8;
const char* src_ptr3 = (char*)sptr + (ih1 * IW + iw1) * C * bit_width / 8;
char* dst_ptr = (char*)dst + (oh * OW + ow) * C * bit_width / 8;
for (int c = 0; c < C; c += pack_c) {
KernCoreNHWC<ctype, OutputConverter, pack_c>::func(
dst_ptr, src_ptr0, src_ptr1, src_ptr2, src_ptr3, c * bit_width / 8,
w00, w01, w10, w11, output_converter, true, true, true, true,
(ctype)0);
}
}
}
template < template <
typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter, typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter,
int pack_c> int pack_c>
...@@ -598,6 +739,58 @@ __global__ void kern_general_nhwc( ...@@ -598,6 +739,58 @@ __global__ void kern_general_nhwc(
} }
} }
template <
typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter,
int pack_c>
__global__ void kern_general_nhwc_const_multi_src(
SrcVisitor srcs, const float* __restrict mat, ctype* __restrict dst, int C,
int IH, int IW, int OH, int OW, ctype bval) {
Getter getter;
OutputConverter output_converter;
constexpr int bit_width = CtypeHelper<ctype>::bit_width;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
const ctype* __restrict sptr = srcs.get(blockIdx.z);
dst = (ctype*)((char*)dst + blockIdx.z * C * OH * OW * bit_width / 8);
mat += blockIdx.z * 3 * 3;
if (ow < OW && oh < OH) {
float denominator = mat[6] * ow + mat[7] * oh + mat[8];
float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator;
float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator;
int iw0 = getter(floor(iw) + 0, IW);
int iw1 = getter(floor(iw) + 1, IW);
int ih0 = getter(floor(ih) + 0, IH);
int ih1 = getter(floor(ih) + 1, IH);
float palpha = ih - floor(ih);
float pbeta = iw - floor(iw);
float nalpha = 1.0f - palpha;
float nbeta = 1.0f - pbeta;
float w00 = nalpha * nbeta;
float w01 = nalpha * pbeta;
float w10 = palpha * nbeta;
float w11 = palpha * pbeta;
const char* src_ptr0 = (char*)sptr + (ih0 * IW + iw0) * C * bit_width / 8;
const char* src_ptr1 = (char*)sptr + (ih0 * IW + iw1) * C * bit_width / 8;
const char* src_ptr2 = (char*)sptr + (ih1 * IW + iw0) * C * bit_width / 8;
const char* src_ptr3 = (char*)sptr + (ih1 * IW + iw1) * C * bit_width / 8;
char* dst_ptr = (char*)dst + (oh * OW + ow) * C * bit_width / 8;
bool okw0 = (iw0 >= 0 && iw0 < IW);
bool okw1 = (iw1 >= 0 && iw1 < IW);
bool okh0 = (ih0 >= 0 && ih0 < IH);
bool okh1 = (ih1 >= 0 && ih1 < IH);
bool src0_ok = okh0 && okw0;
bool src1_ok = okh0 && okw1;
bool src2_ok = okh1 && okw0;
bool src3_ok = okh1 && okw1;
for (int c = 0; c < C; c += pack_c) {
KernCoreNHWC<ctype, OutputConverter, pack_c>::func(
dst_ptr, src_ptr0, src_ptr1, src_ptr2, src_ptr3, c * bit_width / 8,
w00, w01, w10, w11, output_converter, src0_ok, src1_ok, src2_ok,
src3_ok, bval);
}
}
}
template < template <
typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter, typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter,
int pack_c> int pack_c>
...@@ -650,6 +843,73 @@ __global__ void kern_general_nhwc_const( ...@@ -650,6 +843,73 @@ __global__ void kern_general_nhwc_const(
} }
} }
template <typename ctype, typename SrcVisitor>
void dispatch_with_visitor_multi_src(
bool is_nhwc, SrcVisitor srcs, const float* mat, ctype* dst, int N, int C,
int IH, int IW, int OH, int OW, ctype bval, BorderMode bmode,
cudaStream_t stream) {
constexpr int pack_c = 1;
const int BY = 16, BX = 32;
#define DISPATCH(Getter) \
do { \
if (is_nhwc) { \
kern_general_nhwc_multi_src< \
ctype, Getter, SrcVisitor, rounding::RoundingConverter<ctype>, \
pack_c><<<blocks, threads, 0, stream>>>( \
srcs, mat, dst, C, IH, IW, OH, OW); \
} else { \
kern_general_multi_src< \
ctype, Getter, SrcVisitor, rounding::RoundingConverter<ctype>> \
<<<blocks, threads, 0, stream>>>( \
srcs, mat, dst, C, IH, IW, OH, OW); \
} \
} while (0)
const int max_batch_size = 65535;
while (N) {
size_t curr_batch_size = N < max_batch_size ? N : max_batch_size;
dim3 threads(BX, BY);
dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size);
switch (bmode) {
case BORDER_REPLICATE:
DISPATCH(ReplicateGetter);
break;
case BORDER_REFLECT:
DISPATCH(ReflectGetter);
break;
case BORDER_REFLECT_101:
DISPATCH(Reflect101Getter);
break;
case BORDER_WRAP:
DISPATCH(WrapGetter);
break;
#undef DISPATCH
case BORDER_CONSTANT:
if (is_nhwc) {
kern_general_nhwc_const_multi_src<
ctype, ConstGetter, SrcVisitor,
rounding::RoundingConverter<ctype>, pack_c>
<<<blocks, threads, 0, stream>>>(
srcs, mat, dst, C, IH, IW, OH, OW, bval);
} else {
kern_const_border_multi_src<
ctype, SrcVisitor, rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(
srcs, mat, dst, C, IH, IW, OH, OW, bval);
}
break;
default:
break;
}
N -= curr_batch_size;
srcs.move_batch(curr_batch_size, C * IH * IW);
mat += curr_batch_size * 3 * 3;
dst += curr_batch_size * C * OH * OW;
}
}
template <typename ctype, typename SrcVisitor> template <typename ctype, typename SrcVisitor>
void dispatch_with_visitor( void dispatch_with_visitor(
bool is_nhwc, SrcVisitor src, const float* mat, ctype* dst, int N, int C, bool is_nhwc, SrcVisitor src, const float* mat, ctype* dst, int N, int C,
...@@ -1534,6 +1794,33 @@ void dispatch_with_visitor_quint8_dimshuffle_typecvt_nchw( ...@@ -1534,6 +1794,33 @@ void dispatch_with_visitor_quint8_dimshuffle_typecvt_nchw(
namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {
namespace warp_perspective { namespace warp_perspective {
template <typename ctype>
void forward_proxy_multi_src(
bool is_nhwc, const ctype** srcs, const float* mat, const int* mat_idx,
ctype* dst, int N_SRC, int N_MAT, int C, int IH, int IW, int OH, int OW,
ctype bval, BorderMode bmode, megcore::AsyncErrorInfo* error_info,
void* error_tracker, cudaStream_t stream) {
if (mat_idx) {
IndexedSrcVisitor<ctype> visitor;
visitor.ptrs = reinterpret_cast<const void**>(srcs);
visitor.ptr = srcs;
visitor.idx = mat_idx;
visitor.N_SRC = N_SRC;
visitor.error_info = error_info;
visitor.error_tracker = error_tracker;
dispatch_with_visitor_multi_src(
is_nhwc, visitor, mat, dst, N_MAT, C, IH, IW, OH, OW, bval, bmode,
stream);
} else {
DirectSrcVisitor<ctype> visitor;
visitor.ptrs = reinterpret_cast<const void**>(srcs);
visitor.ptr = srcs;
dispatch_with_visitor_multi_src(
is_nhwc, visitor, mat, dst, N_MAT, C, IH, IW, OH, OW, bval, bmode,
stream);
}
after_kernel_launch();
}
template <typename ctype> template <typename ctype>
void forward_proxy( void forward_proxy(
...@@ -1643,6 +1930,17 @@ INST(dt_float16) ...@@ -1643,6 +1930,17 @@ INST(dt_float16)
INST(int8_t) INST(int8_t)
#undef INST #undef INST
#define INST(ctype) \
template void forward_proxy_multi_src( \
bool, const ctype**, const float*, const int*, ctype*, int, int, int, int, \
int, int, int, ctype, BorderMode, megcore::AsyncErrorInfo*, void*, \
cudaStream_t);
INST(float)
#ifndef MEGDNN_DISABLE_FLOAT16
INST(dt_float16)
#endif
#undef INST
#define INST(ctype) \ #define INST(ctype) \
template void forward_proxy_nchw4( \ template void forward_proxy_nchw4( \
const ctype*, const float*, const int*, ctype*, int, int, int, int, int, \ const ctype*, const float*, const int*, ctype*, int, int, int, int, int, \
......
...@@ -15,12 +15,22 @@ public: ...@@ -15,12 +15,22 @@ public:
void exec( void exec(
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; _megdnn_tensor_out dst, _megdnn_workspace workspace) override;
void exec(
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat,
_megdnn_tensor_in mat_idx, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes( size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst) override { const TensorLayout& mat_idx, const TensorLayout& dst) override {
return get_workspace_bundle(nullptr, src, mat, mat_idx, dst) return get_workspace_bundle(nullptr, src, mat, mat_idx, dst)
.total_size_in_bytes(); .total_size_in_bytes();
} }
size_t get_workspace_in_bytes(
const TensorLayoutArray& srcs, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst) override {
return get_workspace_bundle(nullptr, srcs, mat, mat_idx, dst)
.total_size_in_bytes();
}
void set_error_tracker(void* tracker) override { m_error_tracker = tracker; } void set_error_tracker(void* tracker) override { m_error_tracker = tracker; }
...@@ -28,6 +38,9 @@ private: ...@@ -28,6 +38,9 @@ private:
WorkspaceBundle get_workspace_bundle( WorkspaceBundle get_workspace_bundle(
void* ptr, const TensorLayout& src, const TensorLayout& mat, void* ptr, const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst) const; const TensorLayout& mat_idx, const TensorLayout& dst) const;
WorkspaceBundle get_workspace_bundle(
void* ptr, const TensorLayoutArray& srcs, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst) const;
}; };
class WarpPerspectiveBackwardDataImpl final : public WarpPerspectiveBackwardData { class WarpPerspectiveBackwardDataImpl final : public WarpPerspectiveBackwardData {
......
...@@ -51,6 +51,56 @@ size_t WarpPerspectiveImpl::get_workspace_in_bytes( ...@@ -51,6 +51,56 @@ size_t WarpPerspectiveImpl::get_workspace_in_bytes(
} }
} }
size_t WarpPerspectiveImpl::get_workspace_in_bytes(
const TensorLayoutArray&, const TensorLayout&, const TensorLayout&,
const TensorLayout& dst) {
if (param().format == param::WarpPerspective::Format::NCHW) {
size_t OH = dst.shape[2], OW = dst.shape[3];
return get_bundle(OH, OW).total_size_in_bytes();
} else {
return 0;
}
}
void WarpPerspectiveImpl::exec(
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat,
_megdnn_tensor_in mat_idx, _megdnn_tensor_in dst, _megdnn_workspace workspace) {
TensorLayoutArray srcs_layout;
for (auto&& src : srcs) {
srcs_layout.push_back(src.layout);
}
check_exec_allow_nhwc_mat_idx(
srcs_layout, mat.layout, mat_idx.layout, dst.layout, workspace.size);
size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
->megcore_dispatcher()
->nr_threads();
if (param().format == Format::NCHW && nr_threads == 1_z) {
#define cb(dt, ct, mct) \
case DTypeTrait<dt>::enumv: { \
auto kparam = KernParam<ct, mct>::from_tensors( \
param().format, param().bmode, param().border_val, srcs, mat, mat_idx, \
dst, workspace); \
MIDOUT_BEGIN(megdnn_fallback_warpperspective, midout_iv(0), dt, ct, mct) { \
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_fallback_multi_src(kparam)); \
return; \
} \
MIDOUT_END(); \
}
switch (srcs.front().layout.dtype.enumv()) {
cb(dtype::Float32, float, float);
DNN_INC_FLOAT16(cb(dtype::Float16, dt_float16, float));
default:
megdnn_throw(ssprintf(
"Unsupported input DType in "
"WarpPerspective: %s",
srcs.front().layout.dtype.name())
.c_str());
}
#undef cb
}
naive::WarpPerspectiveForwardImpl::exec(srcs, mat, mat_idx, dst, workspace);
}
void WarpPerspectiveImpl::exec( void WarpPerspectiveImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
_megdnn_tensor_in dst, _megdnn_workspace workspace) { _megdnn_tensor_in dst, _megdnn_workspace workspace) {
...@@ -95,6 +145,69 @@ void WarpPerspectiveImpl::exec( ...@@ -95,6 +145,69 @@ void WarpPerspectiveImpl::exec(
naive::WarpPerspectiveForwardImpl::exec(src, mat, mat_idx, dst, workspace); naive::WarpPerspectiveForwardImpl::exec(src, mat, mat_idx, dst, workspace);
} }
template <typename ctype, typename mtype>
void WarpPerspectiveImpl::kern_fallback_multi_src(
const KernParam<ctype, mtype>& kern_param) {
UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param);
// cause error if accidentally used
sptr = nullptr;
mptr = nullptr;
dptr = nullptr;
MEGDNN_MARK_USED_VAR(sptr);
MEGDNN_MARK_USED_VAR(mptr);
MEGDNN_MARK_USED_VAR(dptr);
MEGDNN_MARK_USED_VAR(border_val);
MEGDNN_MARK_USED_VAR(IH);
MEGDNN_MARK_USED_VAR(IW);
KernParam<ctype, mtype> sub_param = kern_param;
sub_param.n_src = 1;
sub_param.n_mat = 1;
sub_param.midx_ptr = RefPtr();
sub_param.src_ptr = RefPtr(kern_param.srcs_ptr.front().get_ptr());
sub_param.mat_ptr = RefPtr(kern_param.mat_ptr.get_ptr());
sub_param.dst_ptr = RefPtr(kern_param.dst_ptr.get_ptr());
sub_param.srcs_ptr = kern_param.srcs_ptr;
rep(n, N_MAT) {
if (midx_ptr) {
size_t idx = midx_ptr[n];
megdnn_assert(
idx < N_SRC, "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu",
n, idx, N_SRC);
sub_param.src_ptr.reset(
static_cast<ctype*>(kern_param.srcs_ptr[idx].get_ptr()));
} else if (n) {
sub_param.src_ptr.reset(
static_cast<ctype*>(kern_param.srcs_ptr[n].get_ptr()));
}
if (is_resize_optimizable(static_cast<mtype*>(sub_param.mat_ptr.get_ptr()))) {
if (bmode == BorderMode::CONSTANT) {
MIDOUT_BEGIN(
megdnn_fallback_warpperspective, midout_iv(1), midout_iv(true),
ctype, mtype) {
kern_resize<true, ctype, mtype>(sub_param);
}
MIDOUT_END();
} else {
MIDOUT_BEGIN(
megdnn_fallback_warpperspective, midout_iv(1), midout_iv(false),
ctype, mtype) {
kern_resize<false, ctype, mtype>(sub_param);
}
MIDOUT_END();
}
} else {
MIDOUT_BEGIN(megdnn_fallback_warpperspective, midout_iv(2), ctype, mtype) {
rep(oh, OH) kern_naive<ctype, mtype>(sub_param, oh);
}
MIDOUT_END();
}
sub_param.mat_ptr += 3 * 3 * sizeof(mtype);
sub_param.dst_ptr += C * OH * OW * sizeof(ctype);
}
}
template <typename ctype, typename mtype> template <typename ctype, typename mtype>
void WarpPerspectiveImpl::kern_fallback(const KernParam<ctype, mtype>& kern_param) { void WarpPerspectiveImpl::kern_fallback(const KernParam<ctype, mtype>& kern_param) {
UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param); UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param);
......
...@@ -9,14 +9,24 @@ protected: ...@@ -9,14 +9,24 @@ protected:
template <typename ctype, typename mtype> template <typename ctype, typename mtype>
void kern_fallback(const KernParam<ctype, mtype>& kern_param); void kern_fallback(const KernParam<ctype, mtype>& kern_param);
template <typename ctype, typename mtype>
void kern_fallback_multi_src(const KernParam<ctype, mtype>& kern_param);
public: public:
using naive::WarpPerspectiveForwardImpl::WarpPerspectiveForwardImpl; using naive::WarpPerspectiveForwardImpl::WarpPerspectiveForwardImpl;
size_t get_workspace_in_bytes( size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst) override; const TensorLayout& mat_idx, const TensorLayout& dst) override;
size_t get_workspace_in_bytes(
const TensorLayoutArray& srcs, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst) override;
void exec( void exec(
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; _megdnn_tensor_out dst, _megdnn_workspace workspace) override;
void exec(
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat,
_megdnn_tensor_in mat_idx, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
private: private:
template <typename ctype> template <typename ctype>
......
...@@ -14,6 +14,119 @@ MIDOUT_DECL(megdnn_naive_warpperspective) ...@@ -14,6 +14,119 @@ MIDOUT_DECL(megdnn_naive_warpperspective)
using namespace megdnn; using namespace megdnn;
using namespace naive; using namespace naive;
template <typename ctype, typename mtype>
void WarpPerspectiveForwardImpl::kern_naive_multi_src(
const KernParam<ctype, mtype>& kern_param, size_t task_id) {
MEGDNN_MARK_USED_VAR(kern_param);
MIDOUT_BEGIN(megdnn_naive_warpperspective, ctype, mtype, midout_iv(0)) {
UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param);
MEGDNN_MARK_USED_VAR(N_MAT);
//! strides of C, H, W on src and dst
size_t sstrd[3], dstrd[3];
auto set_sstrd = [&](size_t s0, size_t s1, size_t s2) {
sstrd[0] = s0;
sstrd[1] = s1;
sstrd[2] = s2;
};
auto set_dstrd = [&](size_t s0, size_t s1, size_t s2) {
dstrd[0] = s0;
dstrd[1] = s1;
dstrd[2] = s2;
};
switch (kern_param.format) {
case Format::NCHW:
set_sstrd(IH * IW, IW, 1);
set_dstrd(OH * OW, OW, 1);
break;
case Format::NHWC:
set_sstrd(1, IW * C, C);
set_dstrd(1, OW * C, C);
break;
default:
megdnn_throw("bad format");
}
auto visit_src = [&sptr, sstrd](size_t c, int h, int w) -> float {
return sptr[sstrd[0] * c + sstrd[1] * h + sstrd[2] * w];
};
auto visit_src_bd = [&sptr, sstrd, border_val](
size_t c, int h, int w) -> float {
if (h != -1 && w != -1) {
return sptr[sstrd[0] * c + sstrd[1] * h + sstrd[2] * w];
} else
return border_val;
};
auto visit_dst = [&dptr, dstrd](size_t c, int h, int w) -> ctype& {
return dptr[dstrd[0] * c + dstrd[1] * h + dstrd[2] * w];
};
rounding::RoundingConverter<ctype> output_converter;
sptr = static_cast<const ctype*>(kern_param.srcs_ptr.front().get_ptr());
size_t n = task_id / OH;
size_t oh = task_id % OH;
mptr = mptr + n * 3 * 3;
dptr = dptr + n * C * OH * OW;
if (midx_ptr) {
size_t idx = midx_ptr[n];
megdnn_assert(
idx < N_SRC, "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu",
n, idx, N_SRC);
sptr = sptrs[idx];
} else if (n) {
sptr = sptrs[n];
}
rep(ow, OW) {
float numeratorw = mptr[0] * ow + mptr[1] * oh + mptr[2];
float numeratorh = mptr[3] * ow + mptr[4] * oh + mptr[5];
float denominator = mptr[6] * ow + mptr[7] * oh + mptr[8];
float alphaw = numeratorw / denominator;
float alphah = numeratorh / denominator;
int iw0 = get_real_coord(std::floor(alphaw) + 0, IW);
int iw1 = get_real_coord(std::floor(alphaw) + 1, IW);
int ih0 = get_real_coord(std::floor(alphah) + 0, IH);
int ih1 = get_real_coord(std::floor(alphah) + 1, IH);
alphaw -= floor(alphaw);
alphah -= floor(alphah);
if (bmode != BorderMode::CONSTANT) {
rep(c, C) {
visit_dst(c, oh, ow) = output_converter(
visit_src(c, ih0, iw0) * (1.0f - alphaw) * (1.0f - alphah) +
visit_src(c, ih0, iw1) * alphaw * (1.0f - alphah) +
visit_src(c, ih1, iw0) * (1.0f - alphaw) * alphah +
visit_src(c, ih1, iw1) * alphaw * alphah);
}
} else {
rep(c, C) {
auto val = visit_src_bd(c, ih0, iw0) * (1.0f - alphaw) *
(1.0f - alphah) +
visit_src_bd(c, ih0, iw1) * alphaw * (1.0f - alphah) +
visit_src_bd(c, ih1, iw0) * (1.0f - alphaw) * alphah +
visit_src_bd(c, ih1, iw1) * alphaw * alphah;
visit_dst(c, oh, ow) =
output_converter(std::isfinite(val) ? val : border_val);
}
}
}
}
MIDOUT_END();
}
#define INST(ctype, mtype) \
template void WarpPerspectiveForwardImpl::kern_naive_multi_src<ctype, mtype>( \
const KernParam<ctype, mtype>&, size_t);
INST(float, float);
#if !MEGDNN_DISABLE_FLOAT16
INST(dt_float16, float);
INST(dt_float16, dt_float16);
INST(dt_bfloat16, float);
INST(dt_bfloat16, dt_bfloat16);
#endif
#undef INST
template <typename ctype, typename mtype> template <typename ctype, typename mtype>
void WarpPerspectiveForwardImpl::kern_naive( void WarpPerspectiveForwardImpl::kern_naive(
const KernParam<ctype, mtype>& kern_param, size_t task_id) { const KernParam<ctype, mtype>& kern_param, size_t task_id) {
...@@ -504,6 +617,71 @@ INST(uint8_t, float, float); ...@@ -504,6 +617,71 @@ INST(uint8_t, float, float);
#undef INST #undef INST
void WarpPerspectiveForwardImpl::exec(
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat,
_megdnn_tensor_in mat_idx, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
TensorLayoutArray srcs_layout;
for (auto&& src : srcs) {
srcs_layout.push_back(src.layout);
}
check_exec_allow_nhwc_mat_idx(
srcs_layout, mat.layout, mat_idx.layout, dst.layout, workspace.size);
size_t batch = dst.layout[0];
#define KERN_NAIVE_MULTI_SRC(ct, mct) \
auto kparam = KernParam<ct, mct>::from_tensors( \
param().format, param().bmode, param().border_val, srcs, mat, mat_idx, \
dst, workspace); \
auto run = [kparam, this](size_t index, size_t) { \
kern_naive_multi_src(kparam, index); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(run, kparam.oh* batch);
#define DISPATCH_ST_MULTI_SRC(dt, ct, mct, kern) \
if (srcs.front().layout.dtype.enumv() == DTypeTrait<dt>::enumv) { \
kern(ct, mct); \
return; \
}
#define DISPATCH_ST_MT_MULTI_SRC(dt, ct, kern) \
if (srcs.front().layout.dtype.enumv() == DTypeTrait<dt>::enumv) { \
if (mat.layout.dtype.enumv() == DTypeTrait<dtype::Float32>::enumv) { \
kern(ct, float); \
return; \
} else { \
kern(ct, ct); \
return; \
} \
}
megdnn_assert(warp::is_dnn_available(
srcs_layout, mat.layout, dst.layout, param().imode, param().format));
/*!
* We currently use floating point for all WarpPerspective
* computation, so even if the input ctype is one of the integer
* type, mtype should always be float32.
*
* \warning It's different with \c WarpAffine, with mtype be float16
* if input type is float16.
*/
DISPATCH_ST_MULTI_SRC(dtype::Float32, float, float, KERN_NAIVE_MULTI_SRC);
DNN_INC_FLOAT16(
DISPATCH_ST_MT_MULTI_SRC(dtype::Float16, dt_float16, KERN_NAIVE_MULTI_SRC));
DNN_INC_FLOAT16(DISPATCH_ST_MT_MULTI_SRC(
dtype::BFloat16, dt_bfloat16, KERN_NAIVE_MULTI_SRC));
megdnn_throw(ssprintf(
"Unsupported input DType in "
"WarpPerspective: %s",
srcs.front().layout.dtype.name())
.c_str());
#undef KERN_NAIVE_MULTI_SRC
#undef DISPATCH_ST_MT_MULTI_SRC
#undef DISPATCH_ST_MULTI_SRC
}
void WarpPerspectiveForwardImpl::exec( void WarpPerspectiveForwardImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
_megdnn_tensor_out dst, _megdnn_workspace workspace) { _megdnn_tensor_out dst, _megdnn_workspace workspace) {
......
...@@ -17,8 +17,70 @@ protected: ...@@ -17,8 +17,70 @@ protected:
DType src_dtype, dst_dtype; DType src_dtype, dst_dtype;
RefPtr src_ptr, mat_ptr, dst_ptr; RefPtr src_ptr, mat_ptr, dst_ptr;
RefPtr midx_ptr; //!< can be null RefPtr midx_ptr; //!< can be null
SmallVector<RefPtr> srcs_ptr;
Workspace workspace; Workspace workspace;
static KernParam from_tensors(
Format format, BorderMode bmode, float border_val,
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat,
_megdnn_tensor_in mat_idx, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
auto src = srcs.front();
KernParam ret;
ret.format = format;
ret.bmode = bmode;
ret.border_val = border_val;
ret.n_src = srcs.size();
ret.src_dtype = src.layout.dtype;
ret.dst_dtype = dst.layout.dtype;
if (mat_idx.raw_ptr()) {
megdnn_assert(mat_idx.layout.ndim == 1);
ret.n_mat = mat_idx.layout.shape[0];
ret.midx_ptr = mat_idx.get_ref_ptr();
} else {
megdnn_assert(mat_idx.layout.ndim == 0);
ret.n_mat = ret.n_src;
ret.midx_ptr = nullptr;
}
if (format == Format::NCHW) {
ret.c = src.layout.shape[1];
ret.ih = src.layout.shape[2];
ret.iw = src.layout.shape[3];
ret.oh = dst.layout.shape[2];
ret.ow = dst.layout.shape[3];
} else {
megdnn_assert(format == Format::NHWC);
ret.c = src.layout.shape[3];
ret.ih = src.layout.shape[1];
ret.iw = src.layout.shape[2];
ret.oh = dst.layout.shape[1];
ret.ow = dst.layout.shape[2];
}
if ((src.layout.dtype.enumv() == DTypeEnum::Float32 ||
DNN_FLOAT16_SELECT(
(src.layout.dtype.enumv() == DTypeEnum::Float16 ||
src.layout.dtype.enumv() == DTypeEnum::BFloat16),
false)) &&
(src.layout.dtype == dst.layout.dtype)) {
for (auto&& s : srcs) {
ret.srcs_ptr.push_back(s.get_ref_ptr());
}
ret.mat_ptr = mat.get_ref_ptr();
ret.dst_ptr = dst.get_ref_ptr();
} else {
for (size_t i = 0; i < srcs.size(); i++) {
ret.srcs_ptr.push_back(nullptr);
}
ret.mat_ptr = nullptr;
ret.dst_ptr = nullptr;
}
ret.src_ptr = nullptr;
ret.workspace = workspace;
return ret;
}
static KernParam from_tensors( static KernParam from_tensors(
Format format, BorderMode bmode, float border_val, Format format, BorderMode bmode, float border_val,
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
...@@ -124,16 +186,29 @@ protected: ...@@ -124,16 +186,29 @@ protected:
template <typename ctype, typename mtype> template <typename ctype, typename mtype>
void kern_naive(const KernParam<ctype, mtype>& kern_param, size_t task_id); void kern_naive(const KernParam<ctype, mtype>& kern_param, size_t task_id);
template <typename ctype, typename mtype>
void kern_naive_multi_src(
const KernParam<ctype, mtype>& kern_param, size_t task_id);
public: public:
using WarpPerspectiveForward::WarpPerspectiveForward; using WarpPerspectiveForward::WarpPerspectiveForward;
void exec( void exec(
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; _megdnn_tensor_out dst, _megdnn_workspace workspace) override;
void exec(
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat,
_megdnn_tensor_in mat_idx, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes( size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&) override { const TensorLayout&) override {
return 0; return 0;
} }
size_t get_workspace_in_bytes(
const TensorLayoutArray&, const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
private: private:
template <typename ctype, typename mtype> template <typename ctype, typename mtype>
...@@ -253,6 +328,10 @@ private: ...@@ -253,6 +328,10 @@ private:
auto mptr = static_cast<const mtype*>(p.mat_ptr.get_ptr()); \ auto mptr = static_cast<const mtype*>(p.mat_ptr.get_ptr()); \
auto dptr = static_cast<ctype*>(p.dst_ptr.get_ptr()); \ auto dptr = static_cast<ctype*>(p.dst_ptr.get_ptr()); \
auto midx_ptr = static_cast<int*>(p.midx_ptr.get_ptr()); \ auto midx_ptr = static_cast<int*>(p.midx_ptr.get_ptr()); \
SmallVector<const ctype*> sptrs; \
for (auto&& s_ptr : p.srcs_ptr) { \
sptrs.push_back(static_cast<const ctype*>(s_ptr.get_ptr())); \
} \
auto bmode = p.bmode; \ auto bmode = p.bmode; \
float border_val = p.border_val float border_val = p.border_val
......
...@@ -50,6 +50,54 @@ void WarpPerspectiveMatIdxProxy::exec( ...@@ -50,6 +50,54 @@ void WarpPerspectiveMatIdxProxy::exec(
tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], W.workspace()); tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], W.workspace());
} }
void WarpPerspectiveMultiSrcProxy::deduce_layout(
WarpPerspectiveForward*, TensorLayoutArray&) {}
void WarpPerspectiveMultiSrcProxy::exec(
WarpPerspectiveForward* opr, const TensorNDArray& tensors) {
if (!W.valid()) {
W = WorkspaceWrapper(opr->handle(), 0);
}
megdnn_assert(tensors.size() >= 3);
bool has_mat_idx = false;
TensorLayout mat_idx_layout;
TensorND mat_idx_tensor;
TensorLayoutArray layouts(tensors.size());
std::transform(
tensors.begin(), tensors.end(), layouts.begin(),
[](const TensorND& tensor) { return tensor.layout; });
auto srcs_layouts = layouts;
srcs_layouts.pop_back(); // dst
if (srcs_layouts.back().ndim == 1) {
has_mat_idx = true;
mat_idx_layout = srcs_layouts.back();
srcs_layouts.pop_back(); // mat_idx;
}
auto mat_layout = srcs_layouts.back();
srcs_layouts.pop_back(); // mat
if (has_mat_idx)
W.update(opr->get_workspace_in_bytes(
srcs_layouts, mat_layout, mat_idx_layout, layouts.back()));
else
W.update(opr->get_workspace_in_bytes(srcs_layouts, mat_layout, layouts.back()));
auto srcs_tensors = tensors;
srcs_tensors.pop_back(); // dst
if (has_mat_idx) {
mat_idx_tensor = srcs_tensors.back();
srcs_tensors.pop_back(); // mat_idx;
}
auto mat_tensor = srcs_tensors.back();
srcs_tensors.pop_back(); // mat
if (has_mat_idx)
opr->exec(
srcs_tensors, mat_tensor, mat_idx_tensor, tensors.back(),
W.workspace());
else
opr->exec(srcs_tensors, mat_tensor, tensors.back(), W.workspace());
}
std::vector<TestArg> warp_perspective::get_cv_args() { std::vector<TestArg> warp_perspective::get_cv_args() {
std::vector<TestArg> args; std::vector<TestArg> args;
......
...@@ -19,6 +19,12 @@ struct WarpPerspectiveMatIdxProxy { ...@@ -19,6 +19,12 @@ struct WarpPerspectiveMatIdxProxy {
void exec(WarpPerspectiveBackwardMat* opr, const TensorNDArray& tensors); void exec(WarpPerspectiveBackwardMat* opr, const TensorNDArray& tensors);
}; };
struct WarpPerspectiveMultiSrcProxy {
WorkspaceWrapper W;
static void deduce_layout(WarpPerspectiveForward*, TensorLayoutArray&);
void exec(WarpPerspectiveForward* opr, const TensorNDArray& tensors);
};
class WarpPerspectiveMatRNG final : public IIDRNG { class WarpPerspectiveMatRNG final : public IIDRNG {
public: public:
WarpPerspectiveMatRNG() : idx(0) {} WarpPerspectiveMatRNG() : idx(0) {}
......
...@@ -887,6 +887,194 @@ TEST_F(CUDA, WARP_PERSPECTIVE_NCHW64_QUINT4) { ...@@ -887,6 +887,194 @@ TEST_F(CUDA, WARP_PERSPECTIVE_NCHW64_QUINT4) {
} }
} }
TEST_F(CUDA, WARP_PERSPECTIVE_MULTI_SRC_NCHW) {
using Param = WarpPerspective::Param;
Param param;
WarpPerspectiveMatRNG rng;
for (auto bmode :
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NCHW;
auto run = [&param, &rng, this](
size_t bs, size_t ih, size_t iw, size_t c, size_t oh,
size_t ow, DType dtype) {
Checker<WarpPerspectiveForward, WarpPerspectiveMultiSrcProxy> checker(
handle_cuda());
checker.set_param(param);
TensorShapeArray shapes;
// src
for (size_t i = 0; i < bs; i++) {
shapes.emplace_back(TensorShape{{1, c, ih, iw}});
checker.set_dtype(i, dtype);
}
// mat
shapes.emplace_back(TensorShape{{bs, 3, 3}});
checker.set_rng(bs, &rng);
// dst
shapes.emplace_back(TensorShape{{bs, c, oh, ow}});
checker.set_dtype(bs + 1, dtype);
checker.execs(shapes);
};
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
run(1, 20, 18, 4, 6, 6, dtype);
run(2, 100, 110, 10, 50, 50, dtype);
run(20, 10, 11, 123, 15, 16, dtype);
run(2200, 10, 11, 3, 11, 12, dtype);
}
}
}
TEST_F(CUDA, WARP_PERSPECTIVE_MULTI_SRC_NHWC) {
using Param = WarpPerspective::Param;
Param param;
WarpPerspectiveMatRNG rng;
for (auto bmode :
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NHWC;
auto run = [&param, &rng, this](
size_t bs, size_t ih, size_t iw, size_t c, size_t oh,
size_t ow, DType dtype) {
Checker<WarpPerspectiveForward, WarpPerspectiveMultiSrcProxy> checker(
handle_cuda());
checker.set_param(param);
TensorShapeArray shapes;
// src
for (size_t i = 0; i < bs; i++) {
shapes.emplace_back(TensorShape{{1, ih, iw, c}});
checker.set_dtype(i, dtype);
}
// mat
shapes.emplace_back(TensorShape{{bs, 3, 3}});
checker.set_rng(bs, &rng);
// dst
shapes.emplace_back(TensorShape{{bs, oh, ow, c}});
checker.set_dtype(bs + 1, dtype);
checker.execs(shapes);
};
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
run(1, 20, 18, 4, 6, 6, dtype);
run(2, 100, 110, 10, 50, 50, dtype);
run(20, 10, 11, 123, 15, 16, dtype);
run(2200, 10, 11, 3, 11, 12, dtype);
}
}
}
TEST_F(CUDA, WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX_NCHW) {
using Param = WarpPerspective::Param;
Param param;
WarpPerspectiveMatRNG rng;
UniformIntRNG idx_rng{0, 0};
for (auto bmode :
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NCHW;
auto run = [&param, &rng, &idx_rng, this](
size_t bs, size_t ih, size_t iw, size_t c, size_t oh,
size_t ow, size_t idx, DType dtype) {
Checker<WarpPerspectiveForward, WarpPerspectiveMultiSrcProxy> checker(
handle_cuda());
checker.set_param(param);
TensorShapeArray shapes;
// src
for (size_t i = 0; i < bs; i++) {
shapes.emplace_back(TensorShape{{1, c, ih, iw}});
checker.set_dtype(i, dtype);
}
// mat
shapes.emplace_back(TensorShape{{idx, 3, 3}});
checker.set_rng(bs, &rng);
// mat_idx
shapes.emplace_back(TensorShape{{idx}});
checker.set_dtype(bs + 1, dtype::Int32());
idx_rng = UniformIntRNG{0, (int)bs - 1};
checker.set_rng(bs + 1, &idx_rng);
// dst
shapes.emplace_back(TensorShape{{idx, c, oh, ow}});
checker.set_dtype(bs + 2, dtype);
checker.execs(shapes);
};
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
run(1, 20, 18, 4, 6, 6, 1, dtype);
run(2, 100, 110, 10, 50, 50, 1, dtype);
run(20, 10, 11, 123, 15, 16, 10, dtype);
run(2200, 10, 11, 3, 11, 12, 100, dtype);
}
}
}
TEST_F(CUDA, WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX_NHWC) {
using Param = WarpPerspective::Param;
Param param;
WarpPerspectiveMatRNG rng;
UniformIntRNG idx_rng{0, 0};
for (auto bmode :
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NHWC;
auto run = [&param, &rng, &idx_rng, this](
size_t bs, size_t ih, size_t iw, size_t c, size_t oh,
size_t ow, size_t idx, DType dtype) {
Checker<WarpPerspectiveForward, WarpPerspectiveMultiSrcProxy> checker(
handle_cuda());
checker.set_param(param);
TensorShapeArray shapes;
// src
for (size_t i = 0; i < bs; i++) {
shapes.emplace_back(TensorShape{{1, ih, iw, c}});
checker.set_dtype(i, dtype);
}
// mat
shapes.emplace_back(TensorShape{{idx, 3, 3}});
checker.set_rng(bs, &rng);
// mat_idx
shapes.emplace_back(TensorShape{{idx}});
checker.set_dtype(bs + 1, dtype::Int32());
idx_rng = UniformIntRNG{0, (int)bs - 1};
checker.set_rng(bs + 1, &idx_rng);
// dst
shapes.emplace_back(TensorShape{{idx, oh, ow, c}});
checker.set_dtype(bs + 2, dtype);
checker.execs(shapes);
};
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
run(1, 20, 18, 4, 6, 6, 1, dtype);
run(2, 100, 110, 10, 50, 50, 1, dtype);
run(20, 10, 11, 123, 15, 16, 10, dtype);
run(2200, 10, 11, 3, 11, 12, 100, dtype);
}
}
}
#if MEGDNN_WITH_BENCHMARK #if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_WARP_PERSPECTIVE_NCHW4) { TEST_F(CUDA, BENCHMARK_WARP_PERSPECTIVE_NCHW4) {
......
...@@ -172,6 +172,190 @@ TEST_F(FALLBACK, WARP_PERSPECTIFVE_NCHW_QUINT8) { ...@@ -172,6 +172,190 @@ TEST_F(FALLBACK, WARP_PERSPECTIFVE_NCHW_QUINT8) {
warp_perspective::run_quint8_test(handle()); warp_perspective::run_quint8_test(handle());
} }
TEST_F(FALLBACK, WARP_PERSPECTIVE_MULTI_SRC_NCHW) {
using Param = WarpPerspective::Param;
Param param;
WarpPerspectiveMatRNG rng;
for (auto bmode :
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NCHW;
auto run = [&param, &rng, this](
size_t bs, size_t ih, size_t iw, size_t c, size_t oh,
size_t ow, DType dtype) {
Checker<WarpPerspectiveForward, WarpPerspectiveMultiSrcProxy> checker(
handle());
checker.set_param(param);
TensorShapeArray shapes;
// src
for (size_t i = 0; i < bs; i++) {
shapes.emplace_back(TensorShape{{1, c, ih, iw}});
checker.set_dtype(i, dtype);
}
// mat
shapes.emplace_back(TensorShape{{bs, 3, 3}});
checker.set_rng(bs, &rng);
// dst
shapes.emplace_back(TensorShape{{bs, c, oh, ow}});
checker.set_dtype(bs + 1, dtype);
checker.execs(shapes);
};
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
run(1, 20, 18, 4, 6, 6, dtype);
run(20, 10, 11, 123, 15, 16, dtype);
run(100, 10, 11, 3, 11, 12, dtype);
}
}
}
TEST_F(FALLBACK, WARP_PERSPECTIVE_MULTI_SRC_NHWC) {
using Param = WarpPerspective::Param;
Param param;
WarpPerspectiveMatRNG rng;
for (auto bmode :
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NHWC;
auto run = [&param, &rng, this](
size_t bs, size_t ih, size_t iw, size_t c, size_t oh,
size_t ow, DType dtype) {
Checker<WarpPerspectiveForward, WarpPerspectiveMultiSrcProxy> checker(
handle());
checker.set_param(param);
TensorShapeArray shapes;
// src
for (size_t i = 0; i < bs; i++) {
shapes.emplace_back(TensorShape{{1, ih, iw, c}});
checker.set_dtype(i, dtype);
}
// mat
shapes.emplace_back(TensorShape{{bs, 3, 3}});
checker.set_rng(bs, &rng);
// dst
shapes.emplace_back(TensorShape{{bs, oh, ow, c}});
checker.set_dtype(bs + 1, dtype);
checker.execs(shapes);
};
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
run(1, 20, 18, 4, 6, 6, dtype);
run(20, 10, 11, 123, 15, 16, dtype);
run(100, 10, 11, 3, 11, 12, dtype);
}
}
}
TEST_F(FALLBACK, WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX_NCHW) {
using Param = WarpPerspective::Param;
Param param;
WarpPerspectiveMatRNG rng;
UniformIntRNG idx_rng{0, 0};
for (auto bmode :
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NCHW;
auto run = [&param, &rng, &idx_rng, this](
size_t bs, size_t ih, size_t iw, size_t c, size_t oh,
size_t ow, size_t idx, DType dtype) {
Checker<WarpPerspectiveForward, WarpPerspectiveMultiSrcProxy> checker(
handle());
checker.set_param(param);
TensorShapeArray shapes;
// src
for (size_t i = 0; i < bs; i++) {
shapes.emplace_back(TensorShape{{1, c, ih, iw}});
checker.set_dtype(i, dtype);
}
// mat
shapes.emplace_back(TensorShape{{idx, 3, 3}});
checker.set_rng(bs, &rng);
// mat_idx
shapes.emplace_back(TensorShape{{idx}});
checker.set_dtype(bs + 1, dtype::Int32());
idx_rng = UniformIntRNG{0, (int)bs - 1};
checker.set_rng(bs + 1, &idx_rng);
// dst
shapes.emplace_back(TensorShape{{idx, c, oh, ow}});
checker.set_dtype(bs + 2, dtype);
checker.execs(shapes);
};
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
run(1, 20, 18, 4, 6, 6, 1, dtype);
run(20, 10, 11, 123, 15, 16, 10, dtype);
run(100, 10, 11, 3, 11, 12, 100, dtype);
}
}
}
TEST_F(FALLBACK, WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX_NHWC) {
using Param = WarpPerspective::Param;
Param param;
WarpPerspectiveMatRNG rng;
UniformIntRNG idx_rng{0, 0};
for (auto bmode :
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NHWC;
auto run = [&param, &rng, &idx_rng, this](
size_t bs, size_t ih, size_t iw, size_t c, size_t oh,
size_t ow, size_t idx, DType dtype) {
Checker<WarpPerspectiveForward, WarpPerspectiveMultiSrcProxy> checker(
handle());
checker.set_param(param);
TensorShapeArray shapes;
// src
for (size_t i = 0; i < bs; i++) {
shapes.emplace_back(TensorShape{{1, ih, iw, c}});
checker.set_dtype(i, dtype);
}
// mat
shapes.emplace_back(TensorShape{{idx, 3, 3}});
checker.set_rng(bs, &rng);
// mat_idx
shapes.emplace_back(TensorShape{{idx}});
checker.set_dtype(bs + 1, dtype::Int32());
idx_rng = UniformIntRNG{0, (int)bs - 1};
checker.set_rng(bs + 1, &idx_rng);
// dst
shapes.emplace_back(TensorShape{{idx, oh, ow, c}});
checker.set_dtype(bs + 2, dtype);
checker.execs(shapes);
};
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
run(1, 20, 18, 4, 6, 6, 1, dtype);
run(20, 10, 11, 123, 15, 16, 10, dtype);
run(100, 10, 11, 3, 11, 12, 100, dtype);
}
}
}
} // namespace test } // namespace test
} // namespace megdnn } // namespace megdnn
......
...@@ -55,6 +55,282 @@ class NanMatRNG : public RNG { ...@@ -55,6 +55,282 @@ class NanMatRNG : public RNG {
}; };
} // namespace } // namespace
TEST_F(NAIVE, WARP_PERSPECTIVE_MULTI_SRC) {
using Param = WarpPerspective::Param;
WarpPerspective::Param param;
auto extra_impl = [&param, this](const TensorNDArray& tensors) {
//! split src
TensorND src = tensors[0]; // n h w c
size_t n = src.layout[0];
TensorNDArray srcs; // n 个 1 h w c
TensorLayoutArray srcs_layouts;
for (size_t i = 0; i < n; i++) {
TensorLayout ly;
ly = TensorLayout{
{1, src.layout[1], src.layout[2], src.layout[3]}, src.layout.dtype};
srcs.emplace_back(malloc(ly.span().dist_byte()), ly);
srcs_layouts.emplace_back(ly);
}
auto split = handle()->create_operator<SplitForward>();
split->param().axis = 0;
auto split_ws_size = split->get_workspace_in_bytes(src.layout, srcs_layouts);
dt_byte* split_ws_ptr = static_cast<dt_byte*>(malloc(split_ws_size));
Workspace split_ws{split_ws_ptr, split_ws_size};
split->exec(src, srcs, split_ws);
auto warp_perspective = handle()->create_operator<WarpPerspective>();
warp_perspective->param() = param;
auto warp_ws_size = warp_perspective->get_workspace_in_bytes(
srcs_layouts, tensors[1].layout, tensors[2].layout);
dt_byte* warp_ws_ptr = static_cast<dt_byte*>(malloc(warp_ws_size));
Workspace warp_ws{warp_ws_ptr, warp_ws_size};
warp_perspective->exec(srcs, tensors[1], tensors[2], warp_ws);
free(split_ws_ptr);
free(warp_ws_ptr);
for (auto&& s : srcs) {
free(s.raw_ptr());
}
};
{
// Float32
Checker<WarpPerspectiveForward> checker(handle());
WarpPerspectiveMatRNG rng;
checker.set_rng(1, &rng);
checker.set_extra_opr_impl(extra_impl);
// NHWC
for (auto bmode :
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NHWC;
checker.set_param(param);
checker.execs({{1, 2, 2, 4}, {1, 3, 3}, {1, 2, 2, 4}});
checker.execs({{2, 10, 10, 4}, {2, 3, 3}, {2, 10, 12, 4}});
checker.execs({{3, 25, 24, 8}, {3, 3, 3}, {3, 12, 10, 8}});
checker.execs({{4, 33, 22, 16}, {4, 3, 3}, {4, 9, 12, 16}});
}
// NCHW
for (auto bmode :
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NCHW;
checker.set_param(param);
checker.execs({{1, 4, 2, 2}, {1, 3, 3}, {1, 4, 2, 2}});
checker.execs({{2, 4, 10, 10}, {2, 3, 3}, {2, 4, 10, 12}});
checker.execs({{3, 8, 25, 24}, {3, 3, 3}, {3, 8, 12, 10}});
checker.execs({{4, 16, 33, 22}, {4, 3, 3}, {4, 16, 9, 12}});
}
}
{
// Float16
Checker<WarpPerspectiveForward> checker(handle());
WarpPerspectiveMatRNG rng;
checker.set_rng(1, &rng);
checker.set_dtype(0, dtype::Float16());
checker.set_dtype(2, dtype::Float16());
checker.set_extra_opr_impl(extra_impl);
// NHWC
for (auto bmode :
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NHWC;
checker.set_param(param);
checker.execs({{1, 2, 2, 4}, {1, 3, 3}, {1, 2, 2, 4}});
checker.execs({{2, 10, 10, 4}, {2, 3, 3}, {2, 10, 12, 4}});
checker.execs({{3, 25, 24, 8}, {3, 3, 3}, {3, 12, 10, 8}});
checker.execs({{4, 33, 22, 16}, {4, 3, 3}, {4, 9, 12, 16}});
}
// NCHW
for (auto bmode :
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NCHW;
checker.set_param(param);
checker.execs({{1, 4, 2, 2}, {1, 3, 3}, {1, 4, 2, 2}});
checker.execs({{2, 4, 10, 10}, {2, 3, 3}, {2, 4, 10, 12}});
checker.execs({{3, 8, 25, 24}, {3, 3, 3}, {3, 8, 12, 10}});
checker.execs({{4, 16, 33, 22}, {4, 3, 3}, {4, 16, 9, 12}});
}
}
}
TEST_F(NAIVE, WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX) {
using Param = WarpPerspective::Param;
WarpPerspective::Param param;
auto extra_impl = [&param, this](const TensorNDArray& tensors) {
//! split src
TensorND src = tensors[0]; // n h w c
size_t n = src.layout[0];
TensorNDArray srcs; // n 个 1 h w c
TensorLayoutArray srcs_layouts;
for (size_t i = 0; i < n; i++) {
TensorLayout ly;
ly = TensorLayout{
{1, src.layout[1], src.layout[2], src.layout[3]}, src.layout.dtype};
srcs.emplace_back(malloc(ly.span().dist_byte()), ly);
srcs_layouts.emplace_back(ly);
}
auto split = handle()->create_operator<SplitForward>();
split->param().axis = 0;
auto split_ws_size = split->get_workspace_in_bytes(src.layout, srcs_layouts);
dt_byte* split_ws_ptr = static_cast<dt_byte*>(malloc(split_ws_size));
Workspace split_ws{split_ws_ptr, split_ws_size};
split->exec(src, srcs, split_ws);
auto warp_perspective = handle()->create_operator<WarpPerspective>();
warp_perspective->param() = param;
auto warp_ws_size = warp_perspective->get_workspace_in_bytes(
srcs_layouts, tensors[1].layout, tensors[2].layout, tensors[3].layout);
dt_byte* warp_ws_ptr = static_cast<dt_byte*>(malloc(warp_ws_size));
Workspace warp_ws{warp_ws_ptr, warp_ws_size};
warp_perspective->exec(srcs, tensors[1], tensors[2], tensors[3], warp_ws);
free(split_ws_ptr);
free(warp_ws_ptr);
for (auto&& s : srcs) {
free(s.raw_ptr());
}
};
{
// Float32
Checker<WarpPerspectiveForward, WarpPerspectiveMatIdxProxy> checker(handle());
WarpPerspectiveMatRNG rng;
checker.set_rng(1, &rng);
checker.set_dtype(0, dtype::Float32());
checker.set_dtype(1, dtype::Float32());
checker.set_dtype(2, dtype::Int32());
checker.set_dtype(3, dtype::Float32());
checker.set_extra_opr_impl(extra_impl);
// NHWC
for (auto bmode :
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NHWC;
checker.set_param(param);
UniformIntRNG idx_rng{0, 0};
checker.set_rng(2, &idx_rng);
checker.execs({{1, 2, 2, 4}, {1, 3, 3}, {1}, {1, 2, 2, 4}});
idx_rng = UniformIntRNG{0, 1};
checker.set_rng(2, &idx_rng);
checker.execs({{2, 10, 10, 4}, {1, 3, 3}, {1}, {1, 10, 12, 4}});
idx_rng = UniformIntRNG{0, 2};
checker.set_rng(2, &idx_rng);
checker.execs({{3, 25, 24, 8}, {2, 3, 3}, {2}, {2, 12, 10, 8}});
checker.execs({{4, 33, 22, 16}, {2, 3, 3}, {2}, {2, 9, 12, 16}});
}
// NCHW
for (auto bmode :
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NCHW;
checker.set_param(param);
UniformIntRNG idx_rng{0, 0};
checker.set_rng(2, &idx_rng);
checker.execs({{1, 4, 2, 2}, {1, 3, 3}, {1}, {1, 4, 2, 2}});
idx_rng = UniformIntRNG{0, 1};
checker.set_rng(2, &idx_rng);
checker.execs({{2, 4, 10, 10}, {1, 3, 3}, {1}, {1, 4, 10, 12}});
idx_rng = UniformIntRNG{0, 2};
checker.set_rng(2, &idx_rng);
checker.execs({{3, 8, 25, 24}, {2, 3, 3}, {2}, {2, 8, 12, 10}});
checker.execs({{4, 16, 33, 22}, {2, 3, 3}, {2}, {2, 16, 9, 12}});
}
}
{
// Float16
Checker<WarpPerspectiveForward, WarpPerspectiveMatIdxProxy> checker(handle());
WarpPerspectiveMatRNG rng;
checker.set_rng(1, &rng);
checker.set_dtype(0, dtype::Float16());
checker.set_dtype(1, dtype::Float32());
checker.set_dtype(2, dtype::Int32());
checker.set_dtype(3, dtype::Float16());
checker.set_extra_opr_impl(extra_impl);
// NHWC
for (auto bmode :
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NHWC;
checker.set_param(param);
UniformIntRNG idx_rng{0, 0};
checker.set_rng(2, &idx_rng);
checker.execs({{1, 2, 2, 4}, {1, 3, 3}, {1}, {1, 2, 2, 4}});
idx_rng = UniformIntRNG{0, 1};
checker.set_rng(2, &idx_rng);
checker.execs({{2, 10, 10, 4}, {1, 3, 3}, {1}, {1, 10, 12, 4}});
idx_rng = UniformIntRNG{0, 2};
checker.set_rng(2, &idx_rng);
checker.execs({{3, 25, 24, 8}, {2, 3, 3}, {2}, {2, 12, 10, 8}});
checker.execs({{4, 33, 22, 16}, {2, 3, 3}, {2}, {2, 9, 12, 16}});
}
// NCHW
for (auto bmode :
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NCHW;
checker.set_param(param);
UniformIntRNG idx_rng{0, 0};
checker.set_rng(2, &idx_rng);
checker.execs({{1, 4, 2, 2}, {1, 3, 3}, {1}, {1, 4, 2, 2}});
idx_rng = UniformIntRNG{0, 1};
checker.set_rng(2, &idx_rng);
checker.execs({{2, 4, 10, 10}, {1, 3, 3}, {1}, {1, 4, 10, 12}});
idx_rng = UniformIntRNG{0, 2};
checker.set_rng(2, &idx_rng);
checker.execs({{3, 8, 25, 24}, {2, 3, 3}, {2}, {2, 8, 12, 10}});
checker.execs({{4, 16, 33, 22}, {2, 3, 3}, {2}, {2, 16, 9, 12}});
}
}
}
TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW4) { TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW4) {
using Param = WarpPerspective::Param; using Param = WarpPerspective::Param;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册