Commit 16324e30 authored by Megvii Engine Team

feat(dnn/cuda): add remap backward

GitOrigin-RevId: 1b1bcf5db3312a4ea15753ecf770f28791b22a1b
Parent 46b68568
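For orientation (an added note, not part of the diff): writing u = y − ⌊y⌋ and v = x − ⌊x⌋ for the fractional parts of the coordinates stored in map_xy, the forward remap computes a bilinear sample

\[ \mathrm{dst}(o) = (1-u)(1-v)\,a_{00} + (1-u)v\,a_{01} + u(1-v)\,a_{10} + uv\,a_{11}, \qquad a_{rc} = \mathrm{src}(\lfloor y\rfloor + r,\ \lfloor x\rfloor + c). \]

RemapBackwardData propagates the output gradient back through these four weights into src, and RemapBackwardMat differentiates the same expression with respect to u and v to obtain the gradient of map_xy.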
......@@ -270,6 +270,41 @@ protected:
};
using Remap = RemapForward;
class RemapBackwardData : public RemapBase {
DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1);
public:
virtual void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& map_xy,
const TensorLayout& diff,
const TensorLayout& grad) = 0;
protected:
void check_exec(const TensorLayout& map_xy, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_in_bytes);
};
class RemapBackwardMat : public RemapBase {
DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1);
public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& map_xy,
const TensorLayout& diff,
const TensorLayout& grad) = 0;
protected:
void check_exec(const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_in_bytes);
};
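A minimal host-side sketch of driving the new operators (hypothetical glue code, not part of this commit: `handle` is assumed to be a valid megdnn::Handle*, and `map_xy`, `diff`, `grad` TensorNDs over device buffers whose layouts pass check_exec()):

// Hypothetical usage sketch -- names and buffers are illustrative only.
auto bwd_data = handle->create_operator<megdnn::RemapBackwardData>();
bwd_data->param().format = megdnn::param::Remap::Format::NCHW;
bwd_data->param().border_type = megdnn::param::Remap::BorderMode::CONSTANT;
// Query the workspace first; the CUDA and naive impls below report 0 bytes,
// so a null workspace pointer is assumed to be acceptable here.
size_t ws = bwd_data->get_workspace_in_bytes(map_xy.layout, diff.layout,
                                             grad.layout);
bwd_data->exec(map_xy, diff, grad, {nullptr, ws});  // grad <- d(loss)/d(src)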
class SeparableFilterBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(SeparableFilterBase, OperatorBase);
DEF_OPR_PARAM(SeparableFilter);
......
......@@ -197,6 +197,8 @@ private:
cb(ROIAlignBackward) \
cb(BatchConvBiasForward) \
cb(Remap) \
cb(RemapBackwardData) \
cb(RemapBackwardMat) \
/*!
* \brief specialize HandleImpl::create_operator for a single opr type;
......
......@@ -50,6 +50,7 @@ void RemapBase::check_layout_fwd(const TensorLayout& src,
megdnn_assert(dst.shape[0] == src.shape[0], "%s", errmsg().c_str());
megdnn_assert(map_xy.shape[3] == 2);
megdnn_assert(map_xy.shape[0] == src.shape[0]);
megdnn_assert_contiguous(src);
// map_xy only supports float32 type
// map_xy always in NHWC format
......@@ -85,6 +86,34 @@ void Remap::check_exec(const TensorLayout& src, const TensorLayout& map_xy,
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void RemapBackwardData::check_exec(const TensorLayout& map_xy,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_in_bytes) {
check_layout_fwd(grad, map_xy, diff);
megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
|| grad.dtype == dtype::BFloat16()),
"Backward Remap only supports Float32/BFloat16.");
auto required_workspace_in_bytes =
get_workspace_in_bytes(map_xy, diff, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void RemapBackwardMat::check_exec(const TensorLayout& src,
const TensorLayout& map_xy,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_in_bytes) {
check_layout_fwd(src, map_xy, diff);
megdnn_assert_eq_layout(map_xy, grad);
megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
|| grad.dtype == dtype::BFloat16()),
"Backward Remap only supports Float32/BFloat16.");
auto required_workspace_in_bytes =
get_workspace_in_bytes(src, map_xy, diff, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/remap/backward_data.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/remap/common.h"
#include "src/cuda/remap/opr_impl.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
void RemapBackwardDataImpl::exec(_megdnn_tensor_in map_xy,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(map_xy.layout, diff.layout, grad.layout, workspace.size);
megdnn_assert(param().imode == param::Remap::InterpolationMode::LINEAR,
"only support LINEAR interpolationMode");
megdnn_assert(param().format == param::Remap::Format::NCHW,
"only support NCHW format for remap backward");
auto stream = cuda_stream(this->handle());
int N, C, IH, IW, OH, OW;
N = grad.layout.shape[0];
C = grad.layout.shape[1];
IH = grad.layout.shape[2];
IW = grad.layout.shape[3];
OH = map_xy.layout.shape[1];
OW = map_xy.layout.shape[2];
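// Note on the dispatch idiom below (added comment): `cb` expands to a
// guarded kernel launch for one (format, border-mode) pair, and
// `support_dtype` lists every border mode for one dtype as a case of the
// switch on the runtime dtype; the trailing megdnn_throw fires only when
// no `cb` branch matched and broke out of the switch.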
#define cb(dt, _format, bmode) \
if (param().format == param::Remap::Format::_format && \
param().border_type == param::Remap::BorderMode::bmode) { \
using ctype = DTypeTrait<dt>::ctype; \
remap::backwarddata_proxy<ctype, param_enumv::Remap::Format::_format, \
::BorderMode::BORDER_##bmode>( \
grad.compatible_ptr<ctype>(), \
map_xy.compatible_ptr<dt_float32>(), \
diff.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, stream); \
break; \
}
#define support_dtype(dt) \
case DTypeTrait<dt>::enumv: { \
cb(dt, NCHW, CONSTANT); \
cb(dt, NCHW, REPLICATE); \
cb(dt, NCHW, REFLECT); \
cb(dt, NCHW, REFLECT_101); \
cb(dt, NCHW, WRAP); \
megdnn_throw("unsupported border type in remap cuda"); \
}
switch (grad.layout.dtype.enumv()) {
support_dtype(dtype::Float32);
support_dtype(dtype::BFloat16);
default:
megdnn_throw("unsupported dtype in remap backward cuda\n");
}
#undef support_dtype
#undef cb
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/remap/backward_data.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include <cuda_runtime.h>
#include "src/common/rounding_converter.cuh"
#include "src/cuda/cv/kernel_common.cuh"
#include "src/cuda/remap/common.h"
#include "src/cuda/utils.cuh"
using namespace megdnn;
using namespace cuda;
using namespace remap;
using namespace rounding;
namespace {
template <const uint32_t format>
__device__ inline int get_offset(int height, int width, int channel, int h,
int w, int c);
template <>
__device__ inline int get_offset<param_enumv::Remap::Format::NCHW>(
int height, int width, int channel, int h, int w, int c) {
return channel * h * w + height * w + width;
}
template <typename ctype, const uint32_t format, ::BorderMode bmode>
struct GetSrcData {
__device__ static inline int get_index(int height, int width, int channel,
int h, int w, int c) {
height = megcv::border_interpolate<bmode>(height, h);
width = megcv::border_interpolate<bmode>(width, w);
return get_offset<format>(height, width, channel, h, w, c);
}
};
template <typename ctype, const uint32_t format>
struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> {
__device__ static inline int get_index(int height, int width, int channel,
int h, int w, int c) {
return (height >= 0 && height < h && width >= 0 && width < w)
? get_offset<format>(height, width, channel, h, w, c)
: -1;
}
};
template <typename ctype, const uint32_t format, ::BorderMode bmode>
__global__ void kern_general(ctype* __restrict grad, const float* map_xy,
const ctype* diff, int C, int IH, int IW, int OH,
int OW) {
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
grad += blockIdx.z * C * IH * IW;
diff += blockIdx.z * C * OH * OW;
map_xy += blockIdx.z * 2 * OH * OW;
RoundingConverter<ctype> round_converter;
if (ow < OW && oh < OH) {
float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
int col = static_cast<int>(floor(index_col));
int row = static_cast<int>(floor(index_row));
float v = index_col - col; // alphaw
float u = index_row - row; // alphah
const float one = 1.f;
for (int c = 0; c < C; ++c) {
float hidden = static_cast<float>(
diff[get_offset<format>(oh, ow, c, OH, OW, C)]);
int a00 = GetSrcData<ctype, format, bmode>::get_index(
row + 0, col + 0, c, IH, IW, C);
if (a00 != -1) {
atomic_add(grad + a00,
round_converter((one - u) * (one - v) * hidden));
}
int a01 = GetSrcData<ctype, format, bmode>::get_index(
row + 0, col + 1, c, IH, IW, C);
if (a01 != -1) {
atomic_add(grad + a01, round_converter((one - u) * v * hidden));
}
int a10 = GetSrcData<ctype, format, bmode>::get_index(
row + 1, col + 0, c, IH, IW, C);
if (a10 != -1) {
atomic_add(grad + a10, round_converter(u * (one - v) * hidden));
}
int a11 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW,
bmode>::get_index(row + 1, col + 1, c, IH, IW,
C);
if (a11 != -1) {
atomic_add(grad + a11, round_converter(u * v * hidden));
}
}
}
}
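Why the atomic_add: distinct output pixels may bilinearly overlap the same source pixel, so the data gradient is a scatter-add. Restating the kernel above (added note),

\[ \frac{\partial L}{\partial\,\mathrm{src}_p} = \sum_{o\,:\,p\in N_4(o)} w_p(o)\,\frac{\partial L}{\partial\,\mathrm{dst}_o}, \qquad w_p(o) \in \{(1-u)(1-v),\ (1-u)v,\ u(1-v),\ uv\}, \]

where N_4(o) is the set of four taps sampled for output pixel o; BORDER_CONSTANT taps report index -1 and contribute nothing.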
template <typename ctype, const uint32_t format, ::BorderMode bmode>
void dispatch_backwarddata(ctype* grad, const float* map_xy, const ctype* diff,
int N, int C, int IH, int IW, int OH, int OW,
cudaStream_t stream) {
const int BX = 32, BY = 16;
const int max_batch_size = 65535;
while (N) {
size_t curr_batch_size = N < max_batch_size ? N : max_batch_size;
dim3 threads(BX, BY);
dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size);
cuda_check(cudaMemsetAsync(
grad, 0, sizeof(ctype) * curr_batch_size * C * IH * IW,
stream));
kern_general<ctype, format, bmode><<<blocks, threads, 0, stream>>>(
grad, map_xy, diff, C, IH, IW, OH, OW);
N -= curr_batch_size;
grad += curr_batch_size * C * IH * IW;
diff += curr_batch_size * C * OH * OW;
map_xy += curr_batch_size * 2 * OH * OW;
}
}
} // anonymous namespace
namespace megdnn {
namespace cuda {
namespace remap {
template <typename ctype, const uint32_t format, ::BorderMode bmode>
void backwarddata_proxy(ctype* grad, const float* map_xy, const ctype* diff,
int N, int C, int IH, int IW, int OH, int OW,
cudaStream_t stream) {
dispatch_backwarddata<ctype, format, bmode>(grad, map_xy, diff, N, C, IH,
IW, OH, OW, stream);
after_kernel_launch();
}
#define INST(ctype, format, bmode) \
template void backwarddata_proxy< \
ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \
ctype*, const float*, const ctype*, int, int, int, int, int, int, \
cudaStream_t);
#define FOR_FORMAT_BMODE(ctype) \
INST(ctype, NCHW, BORDER_CONSTANT) \
INST(ctype, NCHW, BORDER_REPLICATE) \
INST(ctype, NCHW, BORDER_REFLECT) \
INST(ctype, NCHW, BORDER_REFLECT_101) \
INST(ctype, NCHW, BORDER_WRAP)
FOR_FORMAT_BMODE(float)
MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
#undef FOR_FORMAT_BMODE
#undef INST
} // namespace remap
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/remap/backward_mat.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/remap/common.h"
#include "src/cuda/remap/opr_impl.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
void RemapBackwardMatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(src.layout, map_xy.layout, diff.layout, grad.layout,
workspace.size);
megdnn_assert(param().imode == param::Remap::InterpolationMode::LINEAR,
"only support LINEAR interpolationMode");
megdnn_assert(param().format == param::Remap::Format::NCHW,
"only support NCHW format for remap backward");
auto stream = cuda_stream(this->handle());
int N, C, IH, IW, OH, OW;
N = src.layout.shape[0];
C = src.layout.shape[1];
IH = src.layout.shape[2];
IW = src.layout.shape[3];
OH = map_xy.layout.shape[1];
OW = map_xy.layout.shape[2];
#define cb(dt, _format, bmode) \
if (param().format == param::Remap::Format::_format && \
param().border_type == param::Remap::BorderMode::bmode) { \
using ctype = DTypeTrait<dt>::ctype; \
remap::backwardmat_proxy<ctype, param_enumv::Remap::Format::_format, \
::BorderMode::BORDER_##bmode>( \
src.compatible_ptr<ctype>(), \
map_xy.compatible_ptr<dt_float32>(), \
diff.compatible_ptr<ctype>(), \
grad.compatible_ptr<dt_float32>(), N, C, IH, IW, OH, OW, \
param().scalar, stream); \
break; \
}
#define support_dtype(dt) \
case DTypeTrait<dt>::enumv: { \
cb(dt, NCHW, CONSTANT); \
cb(dt, NCHW, REPLICATE); \
cb(dt, NCHW, REFLECT); \
cb(dt, NCHW, REFLECT_101); \
cb(dt, NCHW, WRAP); \
megdnn_throw("unsupported border type in remap cuda"); \
}
switch (src.layout.dtype.enumv()) {
support_dtype(dtype::Float32);
support_dtype(dtype::BFloat16);
default:
megdnn_throw("unsupported dtype in remap backward cuda\n");
}
#undef support_dtype
#undef cb
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/remap/backward_mat.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include <cuda_runtime.h>
#include "src/common/rounding_converter.cuh"
#include "src/cuda/cv/kernel_common.cuh"
#include "src/cuda/remap/common.h"
#include "src/cuda/utils.cuh"
using namespace megdnn;
using namespace cuda;
using namespace remap;
using namespace rounding;
namespace {
template <const uint32_t format>
__device__ inline int get_offset(int height, int width, int channel, int h,
int w, int c);
template <>
__device__ inline int get_offset<param_enumv::Remap::Format::NCHW>(
int height, int width, int channel, int h, int w, int c) {
return channel * h * w + height * w + width;
}
template <typename ctype, const uint32_t format, ::BorderMode bmode>
struct GetSrcData {
__device__ static inline int get_index(int height, int width, int channel,
int h, int w, int c) {
height = megcv::border_interpolate<bmode>(height, h);
width = megcv::border_interpolate<bmode>(width, w);
return get_offset<format>(height, width, channel, h, w, c);
}
};
template <typename ctype, const uint32_t format>
struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> {
__device__ static inline int get_index(int height, int width, int channel,
int h, int w, int c) {
return (height >= 0 && height < h && width >= 0 && width < w)
? get_offset<format>(height, width, channel, h, w, c)
: -1;
}
};
template <typename ctype, const uint32_t format, ::BorderMode bmode>
__global__ void kern_general(const ctype* src, const float* map_xy,
const ctype* diff, float* __restrict grad, int C,
int IH, int IW, int OH, int OW, float scalar) {
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
src += blockIdx.z * C * IH * IW;
diff += blockIdx.z * C * OH * OW;
map_xy += blockIdx.z * 2 * OH * OW;
grad += blockIdx.z * 2 * OH * OW;
RoundingConverter<ctype> round_converter;
if (ow < OW && oh < OH) {
float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
int col = static_cast<int>(floor(index_col));
int row = static_cast<int>(floor(index_row));
float v = index_col - col; // alphaw
float u = index_row - row; // alphah
const float one = 1.f;
for (int c = 0; c < C; ++c) {
float hidden = static_cast<float>(
diff[get_offset<format>(
oh, ow, c, OH, OW, C)]);
float du = 0.f, dv = 0.f;
int a00 = GetSrcData<ctype, format, bmode>::get_index(
row + 0, col + 0, c, IH, IW, C);
int a01 = GetSrcData<ctype, format, bmode>::get_index(
row + 0, col + 1, c, IH, IW, C);
int a10 = GetSrcData<ctype, format, bmode>::get_index(
row + 1, col + 0, c, IH, IW, C);
int a11 = GetSrcData<ctype, format, bmode>::get_index(
row + 1, col + 1, c, IH, IW, C);
dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u);
dv += ((a01 != -1) ? src[a01] : scalar) * (one - u);
dv -= ((a10 != -1) ? src[a10] : scalar) * u;
dv += ((a11 != -1) ? src[a11] : scalar) * u;
du -= ((a00 != -1) ? src[a00] : scalar) * (one - v);
du -= ((a01 != -1) ? src[a01] : scalar) * v;
du += ((a10 != -1) ? src[a10] : scalar) * (one - v);
du += ((a11 != -1) ? src[a11] : scalar) * v;
grad[oh * OW * 2 + ow * 2 + 0] += round_converter(hidden * dv);
grad[oh * OW * 2 + ow * 2 + 1] += round_converter(hidden * du);
}
}
}
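The eight dv/du accumulations above are the partial derivatives of the bilinear sample; differentiating dst = (1-u)(1-v)a00 + (1-u)v a01 + u(1-v) a10 + uv a11 gives (added note)

\[ \frac{\partial\,\mathrm{dst}}{\partial v} = (1-u)(a_{01}-a_{00}) + u(a_{11}-a_{10}), \qquad \frac{\partial\,\mathrm{dst}}{\partial u} = (1-v)(a_{10}-a_{00}) + v(a_{11}-a_{01}), \]

which matches the sign pattern line by line; out-of-range taps fall back to `scalar`, consistent with BORDER_CONSTANT in the forward pass, and the chain rule multiplies each derivative by `hidden`.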
template <typename ctype, const uint32_t format, ::BorderMode bmode>
void dispatch_backwardmat(const ctype* src, const float* map_xy,
const ctype* diff, float* grad, int N, int C, int IH,
int IW, int OH, int OW, float scalar,
cudaStream_t stream) {
const int BX = 32, BY = 16;
const int max_batch_size = 65535;
while (N) {
size_t curr_batch_size = N < max_batch_size ? N : max_batch_size;
dim3 threads(BX, BY);
dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size);
cuda_check(cudaMemsetAsync(
grad, 0, sizeof(float) * curr_batch_size * OH * OW * 2,
stream));
kern_general<ctype, format, bmode><<<blocks, threads, 0, stream>>>(
src, map_xy, diff, grad, C, IH, IW, OH, OW, scalar);
N -= curr_batch_size;
src += curr_batch_size * C * IH * IW;
diff += curr_batch_size * C * OH * OW;
map_xy += curr_batch_size * 2 * OH * OW;
grad += curr_batch_size * 2 * OH * OW;
}
}
} // anonymous namespace
namespace megdnn {
namespace cuda {
namespace remap {
template <typename ctype, const uint32_t format, ::BorderMode bmode>
void backwardmat_proxy(const ctype* src, const float* map_xy, const ctype* diff,
float* grad, int N, int C, int IH, int IW, int OH,
int OW, float scalar, cudaStream_t stream) {
dispatch_backwardmat<ctype, format, bmode>(src, map_xy, diff, grad, N, C,
IH, IW, OH, OW, scalar, stream);
after_kernel_launch();
}
#define INST(ctype, format, bmode) \
template void backwardmat_proxy<ctype, param_enumv::Remap::Format::format, \
::BorderMode::bmode>( \
const ctype*, const float*, const ctype*, float*, int, int, int, \
int, int, int, float, cudaStream_t);
#define FOR_FORMAT_BMODE(ctype) \
INST(ctype, NCHW, BORDER_CONSTANT) \
INST(ctype, NCHW, BORDER_REPLICATE) \
INST(ctype, NCHW, BORDER_REFLECT) \
INST(ctype, NCHW, BORDER_REFLECT_101) \
INST(ctype, NCHW, BORDER_WRAP)
FOR_FORMAT_BMODE(float)
MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
#undef FOR_FORMAT_BMODE
#undef INST
} // namespace remap
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -24,7 +24,17 @@ namespace remap {
template <typename ctype, const uint32_t format, ::BorderMode bmode>
void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N,
int C, int IH, int IW, int OH, int OW, float scalar,
cudaStream_t stream);
template <typename ctype, const uint32_t format, ::BorderMode bmode>
void backwarddata_proxy(ctype* grad, const float* map_xy, const ctype* diff,
int N, int C, int IH, int IW, int OH, int OW,
cudaStream_t stream);
template <typename ctype, const uint32_t format, ::BorderMode bmode>
void backwardmat_proxy(const ctype* src, const float* map_xy, const ctype* diff,
float* grad, int N, int C, int IH, int IW, int OH,
int OW, float scalar, cudaStream_t stream);
} // namespace remap
} // namespace cuda
......
......@@ -22,9 +22,10 @@ using namespace cuda;
void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
                     _megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(src.layout, map_xy.layout, dst.layout, workspace.size);
megdnn_assert(map_xy.layout.dtype.enumv() ==
DTypeTrait<dtype::Float32>::enumv);
auto stream = cuda_stream(this->handle());
int N, C, IH, IW, OH, OW;
OH = map_xy.layout.shape[1];
OW = map_xy.layout.shape[2];
......@@ -36,10 +37,6 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy,
C = src.layout.shape[1];
IH = src.layout.shape[2];
IW = src.layout.shape[3];
} else if (param().format == param::Remap::Format::NHWC) {
N = src.layout.shape[0];
C = src.layout.shape[3];
......@@ -58,7 +55,7 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy,
src.compatible_ptr<ctype>(), \
map_xy.compatible_ptr<dt_float32>(), \
dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, \
param().scalar, stream); \
break; \
}
......@@ -78,15 +75,16 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy,
}
switch (src.layout.dtype.enumv()) {
support_dtype(dtype::Float32);
MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16));
MEGDNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
support_dtype(dtype::Int8);
support_dtype(dtype::Uint8);
default:
megdnn_throw("unsupported dtype in remap cuda");
}
#undef support_dtype
#undef cb
}
......
......@@ -23,17 +23,6 @@ using namespace rounding;
namespace {
template <const uint32_t format>
__device__ inline int get_offset(int height, int width, int channel, int h,
int w, int c);
......@@ -74,14 +63,13 @@ struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> {
}
};
template <typename ctype, ::BorderMode bmode>
__global__ void kern_general(const ctype* __restrict sptr, const float* map_xy,
                             ctype* __restrict dst, int C, int IH, int IW,
                             int OH, int OW, float scalar) {
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
sptr += blockIdx.z * C * IH * IW;
dst += blockIdx.z * C * OH * OW;
map_xy += blockIdx.z * 2 * OH * OW;
RoundingConverter<ctype> round_converter;
......@@ -89,8 +77,8 @@ __global__ void kern_general(SrcVisitor src, const float* map_xy,
if (ow < OW && oh < OH) {
float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
int col = static_cast<int>(floor(index_col));
int row = static_cast<int>(floor(index_row));
float v = index_col - col;
float u = index_row - row;
for (int c = 0; c < C; ++c) {
......@@ -106,22 +94,25 @@ __global__ void kern_general(SrcVisitor src, const float* map_xy,
ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW,
bmode>::get(sptr, row + 1, col + 1, c, IH,
IW, C, scalar);
/* in remap, we use float as the type of intermediate result */
float result = static_cast<float>(a00) * (1.f - u) * (1.f - v) +
static_cast<float>(a01) * (1.f - u) * v +
static_cast<float>(a10) * (1.f - v) * u +
static_cast<float>(a11) * u * v;
dst[get_offset<param_enumv::Remap::Format::NCHW>(
oh, ow, c, OH, OW, C)] = round_converter(result);
}
}
}
template <typename ctype, ::BorderMode bmode>
__global__ void kern_general_nhwc(const ctype* __restrict sptr,
const float* map_xy, ctype* __restrict dst,
int C, int IH, int IW, int OH, int OW,
float scalar) {
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
sptr += blockIdx.z * C * IH * IW;
dst += blockIdx.z * C * OH * OW;
map_xy += blockIdx.z * 2 * OH * OW;
RoundingConverter<ctype> round_converter;
......@@ -129,8 +120,8 @@ __global__ void kern_general_nhwc(SrcVisitor src, const float* map_xy,
if (ow < OW && oh < OH) {
float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
int col = static_cast<int>(floor(index_col));
int row = static_cast<int>(floor(index_row));
float v = index_col - col;
float u = index_row - row;
for (int c = 0; c < C; ++c) {
......@@ -146,21 +137,21 @@ __global__ void kern_general_nhwc(SrcVisitor src, const float* map_xy,
ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC,
bmode>::get(sptr, row + 1, col + 1, c, IH,
IW, C, scalar);
/* in remap, we use float as the type of intermediate result */
float result = static_cast<float>(a00) * (1.f - u) * (1.f - v) +
static_cast<float>(a01) * (1.f - u) * v +
static_cast<float>(a10) * (1.f - v) * u +
static_cast<float>(a11) * u * v;
dst[get_offset<param_enumv::Remap::Format::NHWC>(
oh, ow, c, OH, OW, C)] = round_converter(result);
}
}
}
template <typename ctype, const uint32_t format, ::BorderMode bmode>
void dispatch_forward(const ctype* src, const float* map_xy, ctype* dst, int N,
int C, int IH, int IW, int OH, int OW, float scalar,
cudaStream_t stream) {
const int BX = 32, BY = 16;
const int max_batch_size = 65535;
......@@ -170,19 +161,17 @@ void dispatch_with_visitor(SrcVisitor src, const float* map_xy, ctype* dst,
dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size);
if (format == param_enumv::Remap::Format::NCHW) {
kern_general<ctype, bmode><<<blocks, threads, 0, stream>>>(
src, map_xy, dst, C, IH, IW, OH, OW, scalar);
} else if (format == param_enumv::Remap::Format::NHWC) {
kern_general_nhwc<ctype, bmode><<<blocks, threads, 0, stream>>>(
src, map_xy, dst, C, IH, IW, OH, OW, scalar);
}
N -= curr_batch_size;
src += curr_batch_size * C * IH * IW;
dst += curr_batch_size * C * OH * OW;
map_xy += curr_batch_size * OH * OW * 2;
}
}
......@@ -195,22 +184,17 @@ namespace remap {
template <typename ctype, const uint32_t format, ::BorderMode bmode>
void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N,
int C, int IH, int IW, int OH, int OW, float scalar,
cudaStream_t stream) {
dispatch_forward<ctype, format, bmode>(src, map_xy, dst, N, C, IH, IW, OH,
OW, scalar, stream);
after_kernel_launch();
}
#define INST(ctype, format, bmode) \
template void forward_proxy<ctype, param_enumv::Remap::Format::format, \
::BorderMode::bmode>( \
const ctype*, const float*, ctype*, int, int, int, int, int, int, \
float, cudaStream_t);
#define FOR_FORMAT_BMODE(ctype) \
INST(ctype, NCHW, BORDER_CONSTANT) \
......@@ -226,11 +210,13 @@ void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N,
FOR_FORMAT_BMODE(float)
MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16))
MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
FOR_FORMAT_BMODE(int8_t)
FOR_FORMAT_BMODE(uint8_t)
#undef FOR_BMODE
#undef FOR_FORMAT_BMODE
#undef INST
} // namespace remap
} // namespace cuda
} // namespace megdnn
......
......@@ -15,13 +15,41 @@
namespace megdnn {
namespace cuda {
class RemapImpl final : public Remap {
public:
using Remap::Remap;
void exec(_megdnn_tensor_in, _megdnn_tensor_in, _megdnn_tensor_out,
_megdnn_workspace) override;
void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& map_xy,
const TensorLayout& dst) override {
return 0;
}
};
class RemapBackwardDataImpl final : public RemapBackwardData {
public:
using RemapBackwardData::RemapBackwardData;
void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& map_xy,
const TensorLayout& diff,
const TensorLayout& grad) override {
return 0;
}
};
class RemapBackwardMatImpl final : public RemapBackwardMat {
public:
using RemapBackwardMat::RemapBackwardMat;
void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& map_xy,
const TensorLayout& diff,
const TensorLayout& grad) override {
return 0;
}
};
......
......@@ -12,11 +12,13 @@
#include "src/naive/remap/opr_impl.h"
#include "src/common/cv/helper.h"
#include "src/common/rounding_converter.cuh"
#include "src/common/utils.h"
#include "src/naive/handle.h"
using namespace megdnn;
using namespace naive;
using namespace rounding;
namespace {
template <param::Remap::Format format>
......@@ -36,35 +38,46 @@ inline int get_offset<param::Remap::Format::NHWC>(int height, int width,
return height * w * c + width * c + channel;
}
template <typename ctype, param::Remap::Format format,
param::Remap::BorderMode bordertype>
struct GetSrcData {
static inline ctype get(const ctype* src, int height, int width,
int channel, int h, int w, int c, float) {
height = megcv::border_interpolate<bordertype>(height, h);
width = megcv::border_interpolate<bordertype>(width, w);
return src[get_offset<format>(height, width, channel, h, w, c)];
}
static inline int get_index(int height, int width, int channel, int h,
int w, int c) {
height = megcv::border_interpolate<bordertype>(height, h);
width = megcv::border_interpolate<bordertype>(width, w);
return get_offset<format>(height, width, channel, h, w, c);
}
};
template <typename ctype, param::Remap::Format format>
struct GetSrcData<ctype, format, param::Remap::BorderMode::CONSTANT> {
static inline ctype get(const ctype* src, int height, int width,
int channel, int h, int w, int c, float scalar) {
RoundingConverter<ctype> round;
return (height >= 0 && height < h && width >= 0 && width < w)
? src[get_offset<format>(height, width, channel, h, w,
c)]
: round(scalar);
}
static inline int get_index(int height, int width, int channel, int h,
int w, int c) {
return (height >= 0 && height < h && width >= 0 && width < w)
? get_offset<format>(height, width, channel, h, w, c)
: -1;
}
};
template <typename ctype, param::Remap::Format format,
          param::Remap::BorderMode bordertype>
void remap_LINEAR(const ctype* src, const float* map_xy, ctype* dst, int N,
                  int C, int IH, int IW, int OH, int OW, float scalar) {
int C, int IH, int IW, int OH, int OW, float scalar) {
RoundingConverter<ctype> round_converter;
for (int n = 0; n < N;
++n, src += C * IH * IW, dst += C * OH * OW, map_xy += OH * OW * 2) {
for (int h = 0; h < OH; ++h) {
......@@ -73,47 +86,131 @@ void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst,
float index_row = map_xy[h * OW * 2 + w * 2 + 1];
int col = static_cast<int>(floor(index_col));
int row = static_cast<int>(floor(index_row));
float v = index_col - col; // alphaw
float u = index_row - row; // alphah
const float one = 1.f;
for (int c = 0; c < C; ++c) {
ctype a00 = GetSrcData<ctype, format, bordertype>::get(
src, row + 0, col + 0, c, IH, IW, C, scalar);
ctype a01 = GetSrcData<ctype, format, bordertype>::get(
src, row + 0, col + 1, c, IH, IW, C, scalar);
ctype a10 = GetSrcData<ctype, format, bordertype>::get(
src, row + 1, col + 0, c, IH, IW, C, scalar);
ctype a11 = GetSrcData<ctype, format, bordertype>::get(
src, row + 1, col + 1, c, IH, IW, C, scalar);
dst[get_offset<format>(h, w, c, OH, OW, C)] =
round_converter(a00 * (one - v) * (one - u) +
a01 * (one - u) * v +
a10 * (one - v) * u + a11 * u * v);
}
}
}
}
}
template <typename ctype, param::Remap::Format format,
          param::Remap::BorderMode bordertype>
void remap_LINEAR_backwarddata(ctype* grad, const float* map_xy,
                               const ctype* diff, int N, int C, int IH, int IW,
                               int OH, int OW) {
    RoundingConverter<ctype> round_converter;
    std::memset(grad, 0, sizeof(ctype) * N * C * IH * IW);
    for (int n = 0; n < N;
         ++n, grad += C * IH * IW, diff += C * OH * OW, map_xy += OH * OW * 2) {
        for (int h = 0; h < OH; ++h) {
            for (int w = 0; w < OW; ++w) {
                float index_col = map_xy[h * OW * 2 + w * 2 + 0];
                float index_row = map_xy[h * OW * 2 + w * 2 + 1];
                int col = static_cast<int>(floor(index_col));
                int row = static_cast<int>(floor(index_row));
                float v = index_col - col;  // alphaw
                float u = index_row - row;  // alphah
                const float one = 1.f;
                for (int c = 0; c < C; ++c) {
                    ctype hidden = diff[get_offset<format>(h, w, c, OH, OW, C)];
int a00 = GetSrcData<ctype, format, bordertype>::get_index(
row + 0, col + 0, c, IH, IW, C);
if (a00 != -1) {
grad[a00] +=
round_converter((one - v) * (one - u) * hidden);
}
int a01 = GetSrcData<ctype, format, bordertype>::get_index(
row + 0, col + 1, c, IH, IW, C);
if (a01 != -1) {
grad[a01] += round_converter((one - u) * v * hidden);
}
int a10 = GetSrcData<ctype, format, bordertype>::get_index(
row + 1, col + 0, c, IH, IW, C);
if (a10 != -1) {
grad[a10] += round_converter(u * (one - v) * hidden);
}
int a11 = GetSrcData<ctype, format, bordertype>::get_index(
row + 1, col + 1, c, IH, IW, C);
if (a11 != -1) {
grad[a11] += round_converter(v * u * hidden);
}
}
}
}
}
}
template <typename ctype, param::Remap::Format format,
param::Remap::BorderMode bordertype>
void remap_LINEAR_backwardmat(const ctype* src, const float* map_xy,
const ctype* diff, float* grad, int N, int C,
int IH, int IW, int OH, int OW, float scalar) {
RoundingConverter<ctype> round_converter;
std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW);
for (int n = 0; n < N; ++n, src += C * IH * IW, diff += C * OH * OW,
map_xy += OH * OW * 2, grad += OH * OW * 2) {
for (int h = 0; h < OH; ++h) {
for (int w = 0; w < OW; ++w) {
float index_col = map_xy[h * OW * 2 + w * 2 + 0];
float index_row = map_xy[h * OW * 2 + w * 2 + 1];
int col = static_cast<int>(floor(index_col));
int row = static_cast<int>(floor(index_row));
float v = index_col - col; // alphaw
float u = index_row - row; // alphah
const float one = 1.f;
for (int c = 0; c < C; ++c) {
float hidden = static_cast<float>(
diff[get_offset<format>(h, w, c, OH, OW, C)]);
float du = 0.f, dv = 0.f;
int a00 = GetSrcData<ctype, format, bordertype>::get_index(
row + 0, col + 0, c, IH, IW, C);
int a01 = GetSrcData<ctype, format, bordertype>::get_index(
row + 0, col + 1, c, IH, IW, C);
int a10 = GetSrcData<ctype, format, bordertype>::get_index(
row + 1, col + 0, c, IH, IW, C);
int a11 = GetSrcData<ctype, format, bordertype>::get_index(
row + 1, col + 1, c, IH, IW, C);
dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u);
dv += ((a01 != -1) ? src[a01] : scalar) * (one - u);
dv -= ((a10 != -1) ? src[a10] : scalar) * u;
dv += ((a11 != -1) ? src[a11] : scalar) * u;
du -= ((a00 != -1) ? src[a00] : scalar) * (one - v);
du -= ((a01 != -1) ? src[a01] : scalar) * v;
du += ((a10 != -1) ? src[a10] : scalar) * (one - v);
du += ((a11 != -1) ? src[a11] : scalar) * v;
grad[h * OW * 2 + w * 2 + 0] +=
round_converter(hidden * dv);
grad[h * OW * 2 + w * 2 + 1] +=
round_converter(hidden * du);
}
}
}
}
}
} // namespace
......@@ -148,8 +245,7 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
src.compatible_ptr<ctype>(), \
map_xy.compatible_ptr<dt_float32>(), \
dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, \
param().scalar))); \
break; \
}
......@@ -172,6 +268,7 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
support_dtype(dtype::Float32);
MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16));
MEGDNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
support_dtype(dtype::Int8);
support_dtype(dtype::Uint8);
#undef cb
......@@ -181,3 +278,109 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
megdnn_throw("unsupported dtype in remap naive\n");
}
}
void RemapBackwardDataImpl::exec(_megdnn_tensor_in map_xy,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(map_xy.layout, diff.layout, grad.layout, workspace.size);
megdnn_assert(param().format == param::Remap::Format::NCHW,
"only support NCHW format for remap backward");
int N, C, IH, IW, OH, OW;
N = grad.layout.shape[0];
C = grad.layout.shape[1];
IH = grad.layout.shape[2];
IW = grad.layout.shape[3];
OH = map_xy.layout.shape[1];
OW = map_xy.layout.shape[2];
switch (diff.layout.dtype.enumv()) {
#define cb(dt, fmt, border, interpolation) \
if (param().format == param::Remap::Format::fmt && \
param().border_type == param::Remap::BorderMode::border && \
param().imode == param::Remap::InterpolationMode::interpolation) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR((remap_##interpolation##_backwarddata< \
ctype, param::Remap::Format::fmt, \
param::Remap::BorderMode::border>( \
grad.compatible_ptr<ctype>(), \
map_xy.compatible_ptr<dt_float32>(), \
diff.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW))); \
break; \
}
#define support_dtype(dt) \
case DTypeTrait<dt>::enumv: { \
cb(dt, NCHW, CONSTANT, LINEAR); \
cb(dt, NCHW, REPLICATE, LINEAR); \
cb(dt, NCHW, REFLECT, LINEAR); \
cb(dt, NCHW, REFLECT_101, LINEAR); \
cb(dt, NCHW, WRAP, LINEAR); \
megdnn_throw( \
"format, border type or imode is incorrect in remap navie " \
"with dtype = " #dt); \
}
support_dtype(dtype::Float32);
MEGDNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
#undef cb
#undef support_dtype
default:
megdnn_throw("unsupported dtype in remap backward naive\n");
}
}
void RemapBackwardMatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(src.layout, map_xy.layout, diff.layout, grad.layout,
workspace.size);
megdnn_assert(param().format == param::Remap::Format::NCHW,
"only support NCHW format for remap backward");
int N, C, IH, IW, OH, OW;
N = src.layout.shape[0];
C = src.layout.shape[1];
IH = src.layout.shape[2];
IW = src.layout.shape[3];
OH = map_xy.layout.shape[1];
OW = map_xy.layout.shape[2];
switch (src.layout.dtype.enumv()) {
#define cb(dt, fmt, border, interpolation) \
if (param().format == param::Remap::Format::fmt && \
param().border_type == param::Remap::BorderMode::border && \
param().imode == param::Remap::InterpolationMode::interpolation) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR((remap_##interpolation##_backwardmat< \
ctype, param::Remap::Format::fmt, \
param::Remap::BorderMode::border>( \
src.compatible_ptr<ctype>(), \
map_xy.compatible_ptr<dt_float32>(), \
diff.compatible_ptr<ctype>(), \
grad.compatible_ptr<dt_float32>(), N, C, IH, IW, OH, OW, \
param().scalar))); \
break; \
}
#define support_dtype(dt) \
case DTypeTrait<dt>::enumv: { \
cb(dt, NCHW, CONSTANT, LINEAR); \
cb(dt, NCHW, REPLICATE, LINEAR); \
cb(dt, NCHW, REFLECT, LINEAR); \
cb(dt, NCHW, REFLECT_101, LINEAR); \
cb(dt, NCHW, WRAP, LINEAR); \
megdnn_throw( \
"format, border type or imode is incorrect in remap navie " \
"with dtype = " #dt); \
}
support_dtype(dtype::Float32);
MEGDNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
#undef cb
#undef support_dtype
default:
megdnn_throw("unsupported dtype in remap backward naive\n");
}
}
// vim: syntax=cpp.doxygen
......@@ -23,6 +23,33 @@ class RemapImpl final : public Remap {
return 0;
}
};
class RemapBackwardDataImpl final : public RemapBackwardData {
public:
using RemapBackwardData::RemapBackwardData;
void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
};
class RemapBackwardMatImpl final : public RemapBackwardMat {
public:
using RemapBackwardMat::RemapBackwardMat;
void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&,
const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
};
} // namespace naive
} // namespace megdnn
......
......@@ -106,6 +106,8 @@ DEF(DeformablePSROIPoolingForward, 5, true, true);
DEF(DeformablePSROIPoolingBackward, 7, true, false);
DEF(BatchConvBiasForward, 5, true, true);
DEF(Remap, 3, true, true);
DEF(RemapBackwardData, 3, true, false);
DEF(RemapBackwardMat, 4, true, false);
} // namespace test
} // namespace megdnn
......
......@@ -46,6 +46,9 @@ static inline std::vector<TestArg> get_nchw_args() {
for (auto border_type : border_mode_vec) {
param.format = fmt;
param.border_type = border_type;
args.emplace_back(param, TensorShape{70000, 1, 2, 2},
TensorShape{70000, 2, 2, 2}, TensorShape{70000, 1, 2, 2});
args.emplace_back(param, TensorShape{1, 1, 2, 2},
TensorShape{1, 2, 2, 2}, TensorShape{1, 1, 2, 2});
......@@ -90,6 +93,9 @@ static inline std::vector<TestArg> get_nhwc_args() {
param.format = fmt;
param.border_type = border_type;
param.scalar = 12.f;
args.emplace_back(param, TensorShape{70000, 2, 2, 1},
TensorShape{70000, 2, 2, 2}, TensorShape{70000, 2, 2, 1});
args.emplace_back(param, TensorShape{1, 2, 2, 1},
TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 2, 1});
......
......@@ -40,6 +40,22 @@ TEST_F(CUDA, REMAP_NCHW_FLOAT) {
cb(dtype::Float32(), float_rng);
cb(dtype::Float16(), float_rng);
#undef cb
#define cb(data_type, data_rng) \
for (auto arg : args) { \
UniformFloatRNG map_rng( \
-2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
checker.set_dtype(0, data_type) \
.set_dtype(1, dtype::Float32()) \
.set_dtype(2, data_type) \
.set_rng(0, &data_rng) \
.set_rng(1, &map_rng) \
.set_rng(2, &data_rng) \
.set_param(arg.param) \
.set_epsilon(1e-2) \
.execs({arg.src, arg.map_xy, arg.dst}); \
}
cb(dtype::BFloat16(), float_rng);
#undef cb
}
TEST_F(CUDA, REMAP_NCHW_INT) {
......@@ -87,6 +103,22 @@ TEST_F(CUDA, REMAP_NHWC_FLOAT) {
cb(dtype::Float32(), float_rng);
cb(dtype::Float16(), float_rng);
#undef cb
#define cb(data_type, data_rng) \
for (auto arg : args) { \
UniformFloatRNG map_rng( \
-2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
checker.set_dtype(0, data_type) \
.set_dtype(1, dtype::Float32()) \
.set_dtype(2, data_type) \
.set_rng(0, &data_rng) \
.set_rng(1, &map_rng) \
.set_rng(2, &data_rng) \
.set_param(arg.param) \
.set_epsilon(1e-2) \
.execs({arg.src, arg.map_xy, arg.dst}); \
}
cb(dtype::BFloat16(), float_rng);
#undef cb
}
TEST_F(CUDA, REMAP_NHWC_INT) {
......@@ -114,6 +146,85 @@ TEST_F(CUDA, REMAP_NHWC_INT) {
#undef cb
}
TEST_F(CUDA, REMAP_BACKWARD_DATA) {
Checker<RemapBackwardData> checker(handle_cuda());
std::vector<TestArg> args = get_nchw_args();
UniformFloatRNG float_rng(0, 255);
#define cb(data_type, data_rng) \
for (auto arg : args) { \
UniformFloatRNG map_rng( \
-2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
checker.set_dtype(1, data_type) \
.set_dtype(0, dtype::Float32()) \
.set_dtype(2, data_type) \
.set_rng(1, &data_rng) \
.set_rng(0, &map_rng) \
.set_rng(2, &data_rng) \
.set_param(arg.param) \
.execs({arg.map_xy, arg.dst, arg.src}); \
}
cb(dtype::Float32(), float_rng);
#undef cb
#define cb(data_type, data_rng) \
for (auto arg : args) { \
UniformFloatRNG map_rng( \
-2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
checker.set_dtype(1, data_type) \
.set_dtype(0, dtype::Float32()) \
.set_dtype(2, data_type) \
.set_rng(1, &data_rng) \
.set_rng(0, &map_rng) \
.set_rng(2, &data_rng) \
.set_param(arg.param) \
.set_epsilon(1e-1) \
.execs({arg.map_xy, arg.dst, arg.src}); \
}
cb(dtype::BFloat16(), float_rng);
#undef cb
}
TEST_F(CUDA, REMAP_BACKWARD_MAT) {
Checker<RemapBackwardMat> checker(handle_cuda());
std::vector<TestArg> args = get_nchw_args();
UniformFloatRNG float_rng(0, 255);
#define cb(data_type, data_rng) \
for (auto arg : args) { \
UniformFloatRNG map_rng( \
-2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
checker.set_dtype(0, data_type) \
.set_dtype(1, dtype::Float32()) \
.set_dtype(2, data_type) \
.set_dtype(3, dtype::Float32()) \
.set_rng(0, &data_rng) \
.set_rng(1, &map_rng) \
.set_rng(2, &data_rng) \
.set_rng(3, &map_rng) \
.set_param(arg.param) \
.set_epsilon(2e-2) \
.execs({arg.src, arg.map_xy, arg.dst, arg.map_xy}); \
}
cb(dtype::Float32(), float_rng);
#undef cb
#define cb(data_type, data_rng) \
for (auto arg : args) { \
UniformFloatRNG map_rng( \
-2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
checker.set_dtype(0, data_type) \
.set_dtype(1, dtype::Float32()) \
.set_dtype(2, data_type) \
.set_dtype(3, dtype::Float32()) \
.set_rng(0, &data_rng) \
.set_rng(1, &map_rng) \
.set_rng(2, &data_rng) \
.set_rng(3, &map_rng) \
.set_param(arg.param) \
.set_epsilon(1e-1) \
.execs({arg.src, arg.map_xy, arg.dst, arg.map_xy}); \
}
cb(dtype::BFloat16(), float_rng);
#undef cb
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_REMAP) {
......@@ -144,13 +255,31 @@ TEST_F(CUDA, BENCHMARK_REMAP) {
.execs(shapes);
auto t2 = benchmarker_cuda.set_display(false).set_param(param).execs(
shapes);
int size = 0;
if (dtype == dtype::Float32{}) {
size = sizeof(float);
printf("float32: ");
} else if (dtype == dtype::Float16{}) {
size = sizeof(dt_float16);
printf("float16: ");
} else if (dtype == dtype::Int8{}) {
size = sizeof(dt_int8);
printf("int8: ");
} else if (dtype == dtype::Uint8{}) {
size = sizeof(dt_uint8);
printf("uint8: ");
}
const TensorShape map_xy = shapes[1];
const TensorShape dst_layout = shapes[2];
float calc_amount = (dst_layout.total_nr_elems() * (4.f + 1.f) * size +
map_xy.total_nr_elems() * sizeof(float)) /
(1024 * 1024 * 1024);
printf("naive={%.3fms, %.3fGBPS}, "
"cuda={%.3fms, %.3fGBPS}\n",
t1 / RUN, calc_amount / (t1 / RUN) * 1e3, t2,
calc_amount / t2 * 1e3);
};
Param param;
param.imode = param::Remap::InterpolationMode::LINEAR;
......
......@@ -84,6 +84,7 @@ from .nn import (
max_pool2d,
one_hot,
prelu,
remap,
roi_align,
roi_pooling,
softmax,
......
......@@ -705,6 +705,61 @@ def warp_perspective(
)
@wrap_io_tensor
def remap(
inp: Tensor,
map_xy: Tensor,
border_mode: str = "REPLICATE",
scalar: float = 0.0,
interp_mode: str = "LINEAR",
) -> Tensor:
r"""
Applies remap transformation to batched 2D images.
The input images are sampled into the output images using the coordinates given in map_xy.
The output's H and W are the same as map_xy's H and W.
:param inp: input image
:param map_xy: (batch, oh, ow, 2) transformation map; the last dimension holds the (x, y) sampling coordinate of each output pixel
:param border_mode: pixel extrapolation method. Default: ``"REPLICATE"``
:param scalar: value used in case of a constant border. Default: ``0``
:param interp_mode: interpolation methods. Default: ``"LINEAR"``
Examples:
.. testcode::
import numpy as np
from megengine import tensor
import megengine.functional as F
inp_shape = (1, 1, 4, 4)
inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
map_xy_shape = (1, 2, 2, 2)
map_xy = tensor(np.array([[[1., 0.],[0., 1.]],
[[0., 1.],[0., 1.]]],
dtype=np.float32).reshape(map_xy_shape))
out = F.remap(inp, map_xy)
print(out.numpy())
Outputs:
.. testoutput::
[[[[1. 4.]
[4. 4.]]]]
"""
return mgb.opr.remap(
inp,
map_xy,
border_type=border_mode,
scalar=scalar,
imode=interp_mode,
format="NCHW",
)
@wrap_io_tensor
def eye(
n: int,
......
......@@ -443,4 +443,29 @@ void RemapForward::init_output_dtype() {
output(0)->dtype(input(0)->dtype());
}
#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(RemapForward) {
mgb_assert(opr.input().size() == 2);
if (wrt_idx == 0) {
SymbolVar grad =
RemapBackwardData::make(opr.input(1), out_grad[0],
opr.input(0), opr.param());
return grad.node();
} else if (wrt_idx == 1) {
SymbolVar grad =
RemapBackwardMat::make(opr.input(0), opr.input(1),
out_grad[0], opr.param());
return grad.node();
} else
return InvalidGrad::make(opr, wrt_idx);
}
#endif
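Restated as equations (added note), the grad rule registered above wires the two new operators into autograd as

\[ \frac{\partial L}{\partial\,\mathrm{src}} = \mathrm{RemapBackwardData}\big(\mathrm{map\_xy},\ \partial L/\partial\,\mathrm{dst}\big), \qquad \frac{\partial L}{\partial\,\mathrm{map\_xy}} = \mathrm{RemapBackwardMat}\big(\mathrm{src},\ \mathrm{map\_xy},\ \partial L/\partial\,\mathrm{dst}\big), \]

with src passed to RemapBackwardData only as a shape hint (in_for_shape).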
/* ====================== RemapBackward ====================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemapBackwardData);
MEGDNN_OPR_INIT3(RemapBackwardData, "remap_bwd_data", 2, false);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemapBackwardMat);
MEGDNN_OPR_INIT3(RemapBackwardMat, "remap_bwd_mat", 1, true);
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -97,6 +97,8 @@ namespace opr {
MGB_SEREG_OPR(ResizeBackward, 2);
MGB_SEREG_OPR(Remap, 2);
MGB_SEREG_OPR(RemapBackwardData, 3);
MGB_SEREG_OPR(RemapBackwardMat, 3);
//! current warp affine version
using WarpAffineV1 = opr::WarpAffine;
......
......@@ -74,7 +74,7 @@ size_t get_workspace_size_bytes(
const TensorShapeArray& output_shapes) const override;
void record_execute_deps(ExecDependencyArray& deps) override;
};
using WarpPerspective = WarpPerspectiveForward;
MGB_DEFINE_OPR_CLASS(
......@@ -98,7 +98,7 @@ static SymbolVar make(SymbolVar mat, SymbolVar mat_idx, SymbolVar out_diff,
const OperatorNodeConfig& config = {});
void scn_do_execute() override;
};
MGB_DEFINE_OPR_CLASS(
WarpPerspectiveBackwardMat,
......@@ -119,8 +119,7 @@ static SymbolVar make(SymbolVar src, SymbolVar mat, SymbolVar mat_idx,
const OperatorNodeConfig& config = {});
void scn_do_execute() override;
};
/* ============================= shape infer ============================== */
//! param: src, dst
......@@ -164,8 +163,7 @@ size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void record_execute_deps(ExecDependencyArray& deps) override;
};
using Resize = ResizeForward;
MGB_DEFINE_OPR_CLASS(ResizeBackward,
......@@ -177,8 +175,7 @@ ResizeBackward(VarNode* out_diff, VarNode* in_for_shape, const Param& param,
static SymbolVar make(SymbolVar out_diff, SymbolVar in_for_shape,
const Param& param = {},
const OperatorNodeConfig& config = {});
};
MGB_DEFINE_OPR_CLASS(RemapForward,
intl::MegDNNOprWrapperFwd<megdnn::RemapForward>) // {
......@@ -192,10 +189,31 @@ static SymbolVar make(SymbolVar in_tensor, SymbolVar map,
private:
void init_output_dtype() override;
};
using Remap = RemapForward;
MGB_DEFINE_OPR_CLASS(RemapBackwardData,
intl::MegDNNOprWrapperBwd<megdnn::RemapBackwardData>) // {
public:
RemapBackwardData(VarNode *map, VarNode *out_diff,
VarNode *in_for_shape, const Param &param,
const OperatorNodeConfig &config);
static SymbolVar make(SymbolVar map, SymbolVar out_diff,
SymbolVar in_for_shape, const Param &param = {},
const OperatorNodeConfig &config = {});
};
MGB_DEFINE_OPR_CLASS(RemapBackwardMat,
intl::MegDNNOprWrapperBwd<megdnn::RemapBackwardMat>) // {
public:
RemapBackwardMat(VarNode *src, VarNode *map, VarNode *out_diff,
const Param &param, const OperatorNodeConfig &config);
static SymbolVar make(SymbolVar src, SymbolVar map, SymbolVar out_diff,
const Param &param = {}, const OperatorNodeConfig &config = {});
};
/*!
* \brief apply affine transformation to batched 2D images
*
......@@ -238,8 +256,7 @@ size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void record_execute_deps(ExecDependencyArray& deps) override;
};
using WarpAffine = WarpAffineForward;
} // opr
......
......@@ -640,11 +640,11 @@ TEST(TestOprImgproc, WarpAffineForward) {
}
TEST(TestOprImgproc, Remap_NCHW) {
constexpr size_t N = 2, C = 8, OH = 10, OW = 10;
opr::Remap::Param param;
using Checker = AutoOprChecker<2, 1>;
TensorShape out_shp{N, C, OH, OW};
param.format = opr::Remap::Param::Format::NCHW;
auto make_graph = [&](const Checker::SymInpArray &inputs) ->
Checker::SymOutArray {
......@@ -657,12 +657,34 @@ TEST(TestOprImgproc, Remap_NCHW) {
opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), dest[0].as_megdnn(), {});
};
std::mt19937 rng(next_rand_seed());
auto rand_real = [&](double lo, double hi) {
auto real = rng() / (std::mt19937::max() + 1.0) * (hi - lo) + lo;
if(std::abs(std::round(real) - real) <= 1e-2)
return real + 1e-1;
return real;
};
auto rand_real2 = [&](double range) {
return rand_real(-range, range);
};
auto gen_mat = [&](HostTensorND& mat) {
auto ptr = mat.ptr<float>();
for (size_t i = 0; i < N; ++ i) {
for(size_t j = 0; j < OH * OW * 2; j++) {
//! the bilinear remap is not differentiable at integer map coordinates
ptr[j] = static_cast<float>(rand_real2(20));
}
ptr += OH * OW * 2;
}
mgb_assert(ptr == mat.ptr<float>() + mat.shape().total_nr_elems());
};
Checker::RunOptions opt;
Checker(make_graph, fwd, CompNode::load("cpu1"))
.set_input_generator(1, gen_mat)
.run({TensorShape{N, C, 3, 20}, TensorShape{N, OH, OW, 2}}, opt)
.run({TensorShape{N, C, 6, 5}, TensorShape{N, OH, OW, 2}}, opt)
.run({TensorShape{N, C, 20, 20}, TensorShape{N, OH, OW, 2}}, opt);
}
TEST(TestOprImgproc, Remap_NHWC) {
......@@ -690,4 +712,5 @@ TEST(TestOprImgproc, Remap_NHWC) {
.run({TensorShape{N, 6, 5, C}, TensorShape{N, 10, 10, 2}}, opt)
.run({TensorShape{N, 20, 20, C}, TensorShape{N, 10, 10, 2}}, opt);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}