提交 813628e2 编写于 作者: M Megvii Engine Team

feat(opr): add interpolate trilinear

GitOrigin-RevId: 19a96ba58bdf645ceb15655dc2f9d29085c677a7
上级 2937ea0e
...@@ -245,6 +245,35 @@ protected: ...@@ -245,6 +245,35 @@ protected:
size_t workspace_in_bytes); size_t workspace_in_bytes);
}; };
class Resize3DBase : public OperatorBase {
DEF_OPR_PARAM(Resize3D);
DEF_OPR_IMPL(Resize3DBase, OperatorBase, 1, 1);
public:
using InterpolationMode = Param::InterpolationMode;
protected:
void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
};
class Resize3DForward : public Resize3DBase {
DEF_OPR_IMPL(Resize3DForward, Resize3DBase, 1, 1);
public:
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;
protected:
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
};
using Resize3D = Resize3DForward;
/** /**
* \brief Remap opr. * \brief Remap opr.
*/ */
......
...@@ -965,6 +965,11 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0) ...@@ -965,6 +965,11 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
'ksize_d', 0, 'ksize_h', 3, 'ksize_w', 3, 'ksize_d', 0, 'ksize_h', 3, 'ksize_w', 3,
'anchor_d', 0, 'anchor_h', 1, 'anchor_w', 1)) 'anchor_d', 0, 'anchor_h', 1, 'anchor_w', 1))
(pdef('Resize3D')
.add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
.add_enum_alias('Format', 'Convolution3D', default=1)
.add_fields('bool', 'align_corners', 'false'))
(pdef('TopK'). (pdef('TopK').
add_enum( add_enum(
'Mode', 'Mode',
......
...@@ -160,6 +160,7 @@ private: ...@@ -160,6 +160,7 @@ private:
cb(GaussianBlur) \ cb(GaussianBlur) \
cb(Resize) \ cb(Resize) \
cb(ResizeBackward) \ cb(ResizeBackward) \
cb(Resize3D) \
cb(ParamPackConcat) \ cb(ParamPackConcat) \
cb(MaxTensorDiff) \ cb(MaxTensorDiff) \
cb(MaskConvForward) \ cb(MaskConvForward) \
......
...@@ -150,6 +150,7 @@ DEF(GroupNormBackward, 8, true, true); ...@@ -150,6 +150,7 @@ DEF(GroupNormBackward, 8, true, true);
DEF(MaskedFill, 3, false, true); DEF(MaskedFill, 3, false, true);
DEF(MultiHeadAttnForward, 11, true, true); DEF(MultiHeadAttnForward, 11, true, true);
DEF(MultiHeadAttnBackward, 15, true, true); DEF(MultiHeadAttnBackward, 15, true, true);
DEF(Resize3D, 2, true, false);
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -111,6 +111,39 @@ std::tuple<float, int, float, int> ResizeBase::get_nearest_linear_coord( ...@@ -111,6 +111,39 @@ std::tuple<float, int, float, int> ResizeBase::get_nearest_linear_coord(
int ResizeBase::get_nearest_src(float scale, int size, int idx) { int ResizeBase::get_nearest_src(float scale, int size, int idx) {
return std::min(static_cast<int>(idx / scale), size - 1); return std::min(static_cast<int>(idx / scale), size - 1);
} }
void Resize3DBase::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) {
auto errmsg = [&]() {
return megdnn_layout_msg(src) + ", " + ", " + megdnn_layout_msg(dst);
};
MEGDNN_MARK_USED_VAR(errmsg);
megdnn_assert(
param().format == Param::Format::NCDHW, "Resize3D only support NCDHW");
megdnn_assert(
src.ndim == 5 && dst.ndim == 5, "shape dim mismatch: %s", errmsg().c_str());
megdnn_assert(src.dtype == dst.dtype, "dtype mismatch: %s", errmsg().c_str());
megdnn_assert(
src.shape[0] == dst.shape[0], "batch size mismatch: %s", errmsg().c_str());
megdnn_assert(
src.shape[1] == dst.shape[1], "channel size mismatch: %s",
errmsg().c_str());
megdnn_assert_contiguous(src);
megdnn_assert_contiguous(dst);
auto imode = param().imode;
using IMode = param::Resize3D::InterpolationMode;
megdnn_assert(imode == IMode::INTER_LINEAR, "Resize3D only support TriLinear mode");
}
void Resize3D::check_exec(
const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
check_layout_fwd(src, dst);
auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -177,6 +177,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpAffine); ...@@ -177,6 +177,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpAffine);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianBlur); MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianBlur);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Resize); MEGDNN_SPECIALIZE_CREATE_OPERATOR(Resize);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ResizeBackward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ResizeBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Resize3D);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaxTensorDiff); MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaxTensorDiff);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskConvForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskConvForward);
......
...@@ -26,6 +26,15 @@ void backward_data_proxy( ...@@ -26,6 +26,15 @@ void backward_data_proxy(
int C, int IH, int IW, int OH, int OW, cudaStream_t stream); int C, int IH, int IW, int OH, int OW, cudaStream_t stream);
} // namespace resize } // namespace resize
namespace resize3d {
template <typename ctype>
void resize3d_forward(
const bool align_corners, const ctype* iptr, ctype* optr, const int N,
const int C, const int ID, const int IH, const int IW, const int OD,
const int OH, const int OW, cudaStream_t stream);
} // namespace resize3d
} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn
......
...@@ -168,4 +168,41 @@ void ResizeImpl::exec( ...@@ -168,4 +168,41 @@ void ResizeImpl::exec(
} }
} }
size_t Resize3DImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) {
return 0;
}
void Resize3DImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
size_t out_depth = dst.layout.shape[2];
size_t out_height = dst.layout.shape[3];
size_t out_width = dst.layout.shape[4];
size_t in_depth = src.layout.shape[2];
size_t in_height = src.layout.shape[3];
size_t in_width = src.layout.shape[4];
bool align_corners = param().align_corners;
auto stream = cuda_stream(this->handle());
if (src.layout.dtype == dtype::Float32{}) {
resize3d::resize3d_forward(
align_corners, src.ptr<dt_float32>(), dst.ptr<dt_float32>(),
src.layout[0], src.layout[1], in_depth, in_height, in_width, out_depth,
out_height, out_width, stream);
#if !MEGDNN_DISABLE_FLOAT16
} else if (src.layout.dtype == dtype::Float16{}) {
resize3d::resize3d_forward(
align_corners, src.ptr<dt_float16>(), dst.ptr<dt_float16>(),
src.layout[0], src.layout[1], in_depth, in_height, in_width, out_depth,
out_height, out_width, stream);
#endif
} else {
megdnn_throw(ssprintf(
"unsupported dtype: %s for Resize3D", src.layout.dtype.name()));
}
}
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -308,6 +308,156 @@ DNN_INC_FLOAT16(INST(dt_float16)) ...@@ -308,6 +308,156 @@ DNN_INC_FLOAT16(INST(dt_float16))
INST(int8_t); INST(int8_t);
#undef INST #undef INST
} // namespace resize } // namespace resize
namespace resize3d {
__device__ __forceinline__ static float pixel_get_src_index(
float scale, int64_t dst_index, bool align_corners) {
if (align_corners) {
return scale * dst_index;
} else {
float src_idx = scale * (dst_index + 0.5f) - 0.5f;
return src_idx < 0.f ? 0.f : src_idx;
}
}
__device__ __forceinline__ static size_t index_getter(
int n, int c, int d, int h, int w, const int N, const int C, const int D,
const int H, const int W) {
return n * C * D * H * W + c * D * H * W + d * H * W + h * W + w;
}
template <typename ctype>
__global__ void trilinear_forward(
const int num_kernels, const float rdepth, const float rheight,
const float rwidth, const bool align_corners, const ctype* iptr, ctype* optr,
const int N, const int C, const int ID, const int IH, const int IW,
const int OD, const int OH, const int OW) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < num_kernels) {
const int w2 = (index % (OH * OW)) % OW;
const int h2 = (index % (OH * OW)) / OW;
const int t2 = index / (OH * OW);
if (ID == OD && IH == OH && IW == OW) {
const int t1 = t2;
const int h1 = h2;
const int w1 = w2;
for (int n = 0; n < N; n++) {
for (int c = 0; c < C; ++c) {
const ctype val =
iptr[index_getter(n, c, t1, h1, w1, N, C, ID, IH, IW)];
optr[index_getter(n, c, t2, h2, w2, N, C, OD, OH, OW)] = val;
}
}
return;
}
const float t1r = pixel_get_src_index(rdepth, t2, align_corners);
const int t1 = t1r;
const int t1p = (t1 < ID - 1) ? 1 : 0;
const float t1lambda = t1r - t1;
const float t0lambda = static_cast<float>(1) - t1lambda;
const float h1r = pixel_get_src_index(rheight, h2, align_corners);
const int h1 = h1r;
const int h1p = (h1 < IH - 1) ? 1 : 0;
const float h1lambda = h1r - h1;
const float h0lambda = static_cast<float>(1) - h1lambda;
const float w1r = pixel_get_src_index(rwidth, w2, align_corners);
const int w1 = w1r;
const int w1p = (w1 < IW - 1) ? 1 : 0;
const float w1lambda = w1r - w1;
const float w0lambda = static_cast<float>(1) - w1lambda;
for (int n = 0; n < N; n++) {
for (int c = 0; c < C; ++c) {
const float val =
t0lambda *
(h0lambda * (w0lambda * iptr[index_getter(
n, c, t1, h1, w1, N, C,
ID, IH, IW)] +
w1lambda * iptr[index_getter(
n, c, t1, h1, w1 + w1p,
N, C, ID, IH, IW)]) +
h1lambda *
(w0lambda * iptr[index_getter(
n, c, t1, h1 + h1p, w1, N,
C, ID, IH, IW)] +
w1lambda *
iptr[index_getter(
n, c, t1, h1 + h1p, w1 + w1p,
N, C, ID, IH, IW)])) +
t1lambda *
(h0lambda *
(w0lambda * iptr[index_getter(
n, c, t1 + t1p, h1, w1, N,
C, ID, IH, IW)] +
w1lambda *
iptr[index_getter(
n, c, t1 + t1p, h1, w1 + w1p,
N, C, ID, IH, IW)]) +
h1lambda *
(w0lambda * iptr[index_getter(
n, c, t1 + t1p, h1 + h1p,
w1, N, C, ID, IH, IW)] +
w1lambda * iptr[index_getter(
n, c, t1 + t1p, h1 + h1p,
w1 + w1p, N, C, ID, IH,
IW)]));
optr[index_getter(n, c, t2, h2, w2, N, C, OD, OH, OW)] =
static_cast<ctype>(val);
}
}
}
}
__host__ __forceinline__ static float get_scale(
int input_size, int output_size, bool align_corners) {
if (align_corners) {
if (output_size > 1) {
return static_cast<float>(input_size - 1) / (output_size - 1);
} else {
return 0.f;
}
} else {
return static_cast<float>(input_size) / output_size;
}
}
template <typename ctype>
void resize3d_forward(
const bool align_corners, const ctype* iptr, ctype* optr, const int N,
const int C, const int ID, const int IH, const int IW, const int OD,
const int OH, const int OW, cudaStream_t stream) {
const size_t num_kernels = OD * OH * OW;
const size_t num_threads = 512;
float rdepth = get_scale(ID, OD, align_corners);
float rheight = get_scale(IH, OH, align_corners);
float rwidth = get_scale(IW, OW, align_corners);
trilinear_forward<ctype>
<<<(num_kernels + num_threads - 1) / num_threads, num_threads, 0, stream>>>(
num_kernels, rdepth, rheight, rwidth, align_corners, iptr, optr, N,
C, ID, IH, IW, OD, OH, OW);
}
#define INST(ctype) \
template void resize3d_forward( \
const bool, const ctype*, ctype*, const int, const int, const int, \
const int, const int, const int, const int, const int, cudaStream_t);
INST(float)
DNN_INC_FLOAT16(INST(dt_float16))
#undef INST
} // namespace resize3d
} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn
......
...@@ -24,6 +24,16 @@ public: ...@@ -24,6 +24,16 @@ public:
const TensorLayout& diff, const TensorLayout& grad) override; const TensorLayout& diff, const TensorLayout& grad) override;
}; };
class Resize3DImpl final : public Resize3D {
public:
using Resize3D::Resize3D;
void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) override;
};
} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn
......
...@@ -545,4 +545,149 @@ void ResizeBackwardImpl::exec( ...@@ -545,4 +545,149 @@ void ResizeBackwardImpl::exec(
} }
} }
template <typename ctype>
void Resize3DImpl::kern_naive(
const float rdepth, const float rheight, const float rwidth,
const bool align_corners, const ctype* iptr, ctype* optr, const int N,
const int C, const int ID, const int IH, const int IW, const int OD,
const int OH, const int OW) {
auto pixel_get_src_index = [](float scale, int64_t dst_index, bool align_corners) {
if (align_corners) {
return scale * dst_index;
} else {
float src_idx = scale * (dst_index + 0.5f) - 0.5f;
return src_idx < 0.f ? 0.f : src_idx;
}
};
auto i_index = [&](int in, int ic, int id, int ih, int iw) -> int {
return in * C * ID * IH * IW + ic * ID * IH * IW + id * IH * IW + ih * IW + iw;
};
auto o_index = [&](int in, int ic, int id, int ih, int iw) -> int {
return in * C * OD * OH * OW + ic * OD * OH * OW + id * OH * OW + ih * OW + iw;
};
for (int t2 = 0; t2 < OD; ++t2) {
for (int h2 = 0; h2 < OH; ++h2) {
for (int w2 = 0; w2 < OW; ++w2) {
const float t1r = pixel_get_src_index(rdepth, t2, align_corners);
const int t1 = t1r;
const int t1p = (t1 < ID - 1) ? 1 : 0;
const float t1lambda = t1r - t1;
const float t0lambda = static_cast<float>(1) - t1lambda;
const float h1r = pixel_get_src_index(rheight, h2, align_corners);
const int h1 = h1r;
const int h1p = (h1 < IH - 1) ? 1 : 0;
const float h1lambda = h1r - h1;
const float h0lambda = static_cast<float>(1) - h1lambda;
const float w1r = pixel_get_src_index(rwidth, w2, align_corners);
const int w1 = w1r;
const int w1p = (w1 < IW - 1) ? 1 : 0;
const float w1lambda = w1r - w1;
const float w0lambda = static_cast<float>(1) - w1lambda;
for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) {
const float val =
t0lambda *
(h0lambda *
(w0lambda *
iptr[i_index(
n, c, t1, h1, w1)] +
w1lambda * iptr[i_index(
n, c, t1, h1,
w1 + w1p)]) +
h1lambda *
(w0lambda * iptr[i_index(
n, c, t1, h1 + h1p,
w1)] +
w1lambda * iptr[i_index(
n, c, t1, h1 + h1p,
w1 + w1p)])) +
t1lambda *
(h0lambda *
(w0lambda * iptr[i_index(
n, c, t1 + t1p, h1,
w1)] +
w1lambda * iptr[i_index(
n, c, t1 + t1p, h1,
w1 + w1p)]) +
h1lambda * (w0lambda * iptr[i_index(
n, c, t1 + t1p,
h1 + h1p, w1)] +
w1lambda * iptr[i_index(
n, c, t1 + t1p,
h1 + h1p,
w1 + w1p)]));
optr[o_index(n, c, t2, h2, w2)] = static_cast<ctype>(val);
}
}
}
}
}
}
#define INST(ctype) \
template void Resize3DImpl::kern_naive( \
const float, const float, const float, const bool, const ctype*, ctype*, \
const int, const int, const int, const int, const int, const int, \
const int, const int)
INST(dt_float32);
DNN_INC_FLOAT16(INST(dt_float16));
#undef INST
void Resize3DImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
bool align_corners = param().align_corners;
size_t N = src.layout.shape[0];
size_t C = src.layout.shape[1];
size_t OD = dst.layout.shape[2];
size_t OH = dst.layout.shape[3];
size_t OW = dst.layout.shape[4];
size_t ID = src.layout.shape[2];
size_t IH = src.layout.shape[3];
size_t IW = src.layout.shape[4];
auto get_scale = [](int input_size, int output_size, bool align_corners) -> float {
if (align_corners) {
if (output_size > 1) {
return static_cast<float>(input_size - 1) / (output_size - 1);
} else {
return 0.f;
}
} else {
return static_cast<float>(input_size) / output_size;
}
};
float rdepth = get_scale(ID, OD, align_corners);
float rheight = get_scale(IH, OH, align_corners);
float rwidth = get_scale(IW, OW, align_corners);
if (src.layout.dtype == dtype::Float32{}) {
Resize3DImpl::kern_naive(
rdepth, rheight, rwidth, align_corners, src.ptr<dt_float32>(),
dst.ptr<dt_float32>(), N, C, ID, IH, IW, OD, OH, OW);
#if !MEGDNN_DISABLE_FLOAT16
} else if (src.layout.dtype == dtype::Float16{}) {
Resize3DImpl::kern_naive(
rdepth, rheight, rwidth, align_corners, src.ptr<dt_float16>(),
dst.ptr<dt_float16>(), N, C, ID, IH, IW, OD, OH, OW);
#endif
} else {
megdnn_throw(ssprintf(
"unsupported dtype: %s for Resize3D", src.layout.dtype.name()));
}
}
size_t Resize3DImpl::get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) {
return 0;
}
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -83,6 +83,24 @@ private: ...@@ -83,6 +83,24 @@ private:
int N, int C, int IH, int IW, int OH, int OW); int N, int C, int IH, int IW, int OH, int OW);
}; };
class Resize3DImpl final : public Resize3D {
public:
using Resize3D::Resize3D;
void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) override;
private:
template <typename ctype>
void kern_naive(
const float rdepth, const float rheight, const float rwidth,
const bool align_corners, const ctype* iptr, ctype* optr, const int N,
const int C, const int ID, const int IH, const int IW, const int OD,
const int OH, const int OW);
};
} // namespace naive } // namespace naive
} // namespace megdnn } // namespace megdnn
......
...@@ -193,6 +193,26 @@ TEST_F(CUDA, RESIZE_BACKWARD) { ...@@ -193,6 +193,26 @@ TEST_F(CUDA, RESIZE_BACKWARD) {
} }
} }
TEST_F(CUDA, RESIZE3D_NCDHW) {
using IMode = param::Resize3D::InterpolationMode;
using Format = param::Resize3D::Format;
auto ac_param = param::Resize3D{IMode::LINEAR, Format::NCDHW, true};
auto nac_param = param::Resize3D{IMode::LINEAR, Format::NCDHW, false};
auto run = [&](DType d, TensorShape ishape, TensorShape oshape) {
Checker<Resize3D> checker(handle_cuda());
checker.set_param(ac_param).set_dtype(0, d).set_dtype(1, d).execs(
{ishape, oshape});
checker.set_param(nac_param).execs({ishape, oshape});
};
for (auto&& dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
run(dtype, {1, 1, 2, 2, 2}, {1, 1, 4, 4, 4});
run(dtype, {2, 2, 2, 3, 4}, {2, 2, 2, 3, 6});
run(dtype, {2, 2, 2, 3, 4}, {2, 2, 3, 4, 5});
}
}
#if MEGDNN_WITH_BENCHMARK #if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_RESIZE_CV) { TEST_F(CUDA, BENCHMARK_RESIZE_CV) {
......
...@@ -61,3 +61,95 @@ TEST_F(NAIVE, RESIZE_NCHW4) { ...@@ -61,3 +61,95 @@ TEST_F(NAIVE, RESIZE_NCHW4) {
.execs({arg.src, arg.dst}); .execs({arg.src, arg.dst});
} }
} }
TEST_F(NAIVE, RESIZE3D_NCDHW) {
using IMode = param::Resize3D::InterpolationMode;
using Format = param::Resize3D::Format;
auto ac_param = param::Resize3D{IMode::LINEAR, Format::NCDHW, true};
auto nac_param = param::Resize3D{IMode::LINEAR, Format::NCDHW, false};
Checker<Resize3D> checker(handle());
checker.set_param(nac_param).exect(
Testcase{
TensorValue(
{1, 1, 2, 2, 2}, dtype::Float32(),
{0., 1., 2., 3., 4., 5., 6., 7.}),
{}},
Testcase{
{},
TensorValue(
{1, 1, 4, 4, 4}, dtype::Float32(),
{0., 0.25, 0.75, 1., 0.5, 0.75, 1.25, 1.5, 1.5, 1.75,
2.25, 2.5, 2., 2.25, 2.75, 3., 1., 1.25, 1.75, 2.,
1.5, 1.75, 2.25, 2.5, 2.5, 2.75, 3.25, 3.5, 3., 3.25,
3.75, 4., 3., 3.25, 3.75, 4., 3.5, 3.75, 4.25, 4.5,
4.5, 4.75, 5.25, 5.5, 5., 5.25, 5.75, 6., 4., 4.25,
4.75, 5., 4.5, 4.75, 5.25, 5.5, 5.5, 5.75, 6.25, 6.5,
6., 6.25, 6.75, 7.})});
checker.set_param(ac_param).exect(
Testcase{
TensorValue(
{1, 1, 2, 2, 2}, dtype::Float32(),
{0., 1., 2., 3., 4., 5., 6., 7.}),
{}},
Testcase{
{},
TensorValue(
{1, 1, 4, 4, 4}, dtype::Float32(),
{0., 0.3333333, 0.6666667, 1., 0.6666667,
1., 1.3333333, 1.6666666, 1.3333334, 1.6666667,
1.9999999, 2.3333333, 2., 2.3333333, 2.6666665,
3., 1.3333334, 1.6666666, 2.0000002, 2.3333335,
2., 2.333333, 2.6666667, 2.9999998, 2.6666665,
3., 3.3333333, 3.6666665, 3.3333333, 3.6666665,
4., 4.3333335, 2.6666667, 3., 3.3333337,
3.6666667, 3.3333335, 3.6666663, 4., 4.333333,
3.9999998, 4.333333, 4.6666665, 5., 4.6666665,
5., 5.3333335, 5.666667, 4., 4.333333,
4.666667, 5., 4.6666665, 4.9999995, 5.3333335,
5.6666665, 5.333333, 5.6666665, 6., 6.3333335,
6., 6.333333, 6.666667, 7.})});
checker.set_param(nac_param).exect(
Testcase{
TensorValue(
{1, 1, 2, 2, 2}, dtype::Float16(),
{0., 1., 2., 3., 4., 5., 6., 7.}),
{}},
Testcase{
{},
TensorValue(
{1, 1, 4, 4, 4}, dtype::Float16(),
{0., 0.25, 0.75, 1., 0.5, 0.75, 1.25, 1.5, 1.5, 1.75,
2.25, 2.5, 2., 2.25, 2.75, 3., 1., 1.25, 1.75, 2.,
1.5, 1.75, 2.25, 2.5, 2.5, 2.75, 3.25, 3.5, 3., 3.25,
3.75, 4., 3., 3.25, 3.75, 4., 3.5, 3.75, 4.25, 4.5,
4.5, 4.75, 5.25, 5.5, 5., 5.25, 5.75, 6., 4., 4.25,
4.75, 5., 4.5, 4.75, 5.25, 5.5, 5.5, 5.75, 6.25, 6.5,
6., 6.25, 6.75, 7.})});
checker.set_param(ac_param).exect(
Testcase{
TensorValue(
{1, 1, 2, 2, 2}, dtype::Float16(),
{0., 1., 2., 3., 4., 5., 6., 7.}),
{}},
Testcase{
{},
TensorValue(
{1, 1, 4, 4, 4}, dtype::Float16(),
{0., 0.3333333, 0.6666667, 1., 0.6666667,
1., 1.3333333, 1.6666666, 1.3333334, 1.6666667,
1.9999999, 2.3333333, 2., 2.3333333, 2.6666665,
3., 1.3333334, 1.6666666, 2.0000002, 2.3333335,
2., 2.333333, 2.6666667, 2.9999998, 2.6666665,
3., 3.3333333, 3.6666665, 3.3333333, 3.6666665,
4., 4.3333335, 2.6666667, 3., 3.3333337,
3.6666667, 3.3333335, 3.6666663, 4., 4.333333,
3.9999998, 4.333333, 4.6666665, 5., 4.6666665,
5., 5.3333335, 5.666667, 4., 4.333333,
4.666667, 5., 4.6666665, 4.9999995, 5.3333335,
5.6666665, 5.333333, 5.6666665, 6., 6.3333335,
6., 6.333333, 6.666667, 7.})});
}
...@@ -474,7 +474,8 @@ def interpolate( ...@@ -474,7 +474,8 @@ def interpolate(
size: the size of the output tensor. Default: None size: the size of the output tensor. Default: None
scale_factor: scaling factor of the output tensor. Default: None scale_factor: scaling factor of the output tensor. Default: None
mode: interpolation methods, acceptable values are: mode: interpolation methods, acceptable values are:
"bilinear", "linear", "bicubic" and "nearest". Default: "bilinear" "bilinear", "linear", "trilinear", "bicubic" and "nearest". Default: "bilinear"
"trilinear" is valid only when inp is a 5D-tensor
align_corners: This only has an effect when ``mode`` align_corners: This only has an effect when ``mode``
is "bilinear" or "linear". Geometrically, we consider the pixels of the input is "bilinear" or "linear". Geometrically, we consider the pixels of the input
and output as squares rather than points. If set to ``True``, the input and output as squares rather than points. If set to ``True``, the input
...@@ -500,9 +501,9 @@ def interpolate( ...@@ -500,9 +501,9 @@ def interpolate(
>>> np.testing.assert_allclose(out.numpy(), out2.numpy()) >>> np.testing.assert_allclose(out.numpy(), out2.numpy())
""" """
mode = mode.lower() mode = mode.lower()
if mode not in ["bilinear", "linear", "bicubic", "nearest"]: if mode not in ["bilinear", "linear", "trilinear", "bicubic", "nearest"]:
raise ValueError("unsupported interpolate mode: {}".format(mode)) raise ValueError("unsupported interpolate mode: {}".format(mode))
if mode not in ["bilinear", "linear"]: if mode not in ["bilinear", "linear", "trilinear"]:
if align_corners is not None: if align_corners is not None:
raise ValueError( raise ValueError(
"align_corners option can only be set in the bilinear/linear interpolating mode" "align_corners option can only be set in the bilinear/linear interpolating mode"
...@@ -514,14 +515,22 @@ def interpolate( ...@@ -514,14 +515,22 @@ def interpolate(
if mode == "linear": if mode == "linear":
inp = expand_dims(inp, 3) inp = expand_dims(inp, 3)
if inp.ndim != 4: if mode == "trilinear":
raise ValueError("shape of input tensor must correspond to the operartion mode") assert (
inp.ndim == 5
), "under trilinear mode, input tensor must have 5 dimensions"
else:
assert (
inp.ndim == 4
), "shape of input tensor must correspond to the operartion mode"
def get_dsize(scale_factor): def get_dsize(scale_factor):
if isinstance(scale_factor, (float, int)): if isinstance(scale_factor, (float, int)):
scale_factor = float(scale_factor) scale_factor = float(scale_factor)
if mode == "linear": if mode == "linear":
scale_factor = (scale_factor, float(1)) scale_factor = (scale_factor, float(1))
elif mode == "trilinear":
scale_factor = (scale_factor, scale_factor, scale_factor)
else: else:
scale_factor = (scale_factor, scale_factor) scale_factor = (scale_factor, scale_factor)
else: else:
...@@ -530,21 +539,28 @@ def interpolate( ...@@ -530,21 +539,28 @@ def interpolate(
"under linear mode, scale_factor can only be single value" "under linear mode, scale_factor can only be single value"
) )
assert len(scale_factor) == 2, "shape of scale_factor must be equal to (2, )" if mode == "trilinear":
assert isinstance(scale_factor[0], float) and isinstance( assert (
scale_factor[1], float len(scale_factor) == 3
), "scale_factor must be float type" ), f"shape of scale_factor of interpolate-{mode} must be equal to (3, )"
dsize = tuple( else:
assert (
len(scale_factor) == 2
), f"shape of scale_factor of interpolate-{mode} must be equal to (2, )"
assert all(
isinstance(x, (float, int)) for x in scale_factor
), f"scale_factor of interpolate must be float/int type"
dsize = [
floor( floor(
Tensor( Tensor(
inp.shape[i + 2] * scale_factor[i], inp.shape[i + 2] * float(scale_factor[i]),
dtype="float32", dtype="float32",
device=inp.device, device=inp.device,
) )
) )
for i in range(2) for i in range(len(scale_factor))
) ]
dsize = concat([dsize[0], dsize[1]], axis=0) dsize = concat(dsize, axis=0)
return dsize return dsize
if size is None: if size is None:
...@@ -557,13 +573,24 @@ def interpolate( ...@@ -557,13 +573,24 @@ def interpolate(
raise ValueError("scale_factor must be None when size is provided") raise ValueError("scale_factor must be None when size is provided")
if isinstance(size, int): if isinstance(size, int):
size = (size, 1) if mode == "trilinear":
size = (size, 1, 1)
else:
size = (size, 1)
else: else:
if mode == "linear": if mode == "linear":
raise ValueError("under linear mode, size can only be single value") raise ValueError("under linear mode, size can only be single value")
dsize = size dsize = size
if not align_corners: if mode == "trilinear":
if inp.dtype == np.float16:
inp = inp.astype("float32")
op = builtin.Resize3D(
imode="linear", format="NCDHW", align_corners=align_corners
)
shape = astensor1d(dsize, inp, dtype="int32", device=inp.device)
(ret,) = apply(op, inp, shape)
elif not align_corners:
# fastpath for interpolate # fastpath for interpolate
mode_map = { mode_map = {
"linear": "linear", "linear": "linear",
......
...@@ -232,7 +232,7 @@ def test_interpolate(): ...@@ -232,7 +232,7 @@ def test_interpolate():
def error_shape_linear_interpolate(): def error_shape_linear_interpolate():
inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2)) inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
with pytest.raises(ValueError): with pytest.raises(AssertionError):
F.vision.interpolate(inp, scale_factor=2.0, mode="linear") F.vision.interpolate(inp, scale_factor=2.0, mode="linear")
def inappropriate_scale_linear_interpolate(): def inappropriate_scale_linear_interpolate():
......
...@@ -465,6 +465,17 @@ def test_resize(): ...@@ -465,6 +465,17 @@ def test_resize():
check_pygraph_dump(fwd, [x], [out]) check_pygraph_dump(fwd, [x], [out])
def test_resize3d():
x = Tensor(np.random.randn(10, 3, 32, 32, 32))
@trace(symbolic=True, capture_as_const=True)
def fwd(x):
return F.vision.interpolate(x, size=(16, 16, 16), mode="trilinear")
out = fwd(x)
check_pygraph_dump(fwd, [x], [out])
def test_index_onehot(): def test_index_onehot():
src = Tensor([[1.0, 2.0]]) src = Tensor([[1.0, 2.0]])
index = Tensor([0]) index = Tensor([0])
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
namespace mgb { namespace mgb {
namespace imperative { namespace imperative {
namespace { namespace resize {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Resize&>(def); auto&& op = static_cast<const Resize&>(def);
...@@ -16,7 +16,21 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { ...@@ -16,7 +16,21 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
} }
OP_TRAIT_REG(Resize, Resize).apply_on_var_node(apply_on_var_node).fallback(); OP_TRAIT_REG(Resize, Resize).apply_on_var_node(apply_on_var_node).fallback();
} // anonymous namespace
} // namespace resize
namespace resize3d {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Resize3D&>(def);
mgb_assert(inputs.size() == 2);
OperatorNodeConfig config{op.make_name()};
return opr::Resize3D::make(inputs[0], inputs[1], op.param(), config);
}
OP_TRAIT_REG(Resize3D, Resize3D).apply_on_var_node(apply_on_var_node).fallback();
} // namespace resize3d
} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb
......
20aa8ae7e128c1e24564ce68389307cc ../../dnn/scripts/opr_param_defs.py 29b2127eb4034bf24e473945d70ead4a ../../dnn/scripts/opr_param_defs.py
9e9636d66694dd7d5a7853247a5406f9 ../../src/core/include/megbrain/ir/ops.td 639ff50d64fcb78374de266c88942c2c ../../src/core/include/megbrain/ir/ops.td
e4489c2e1ea2b680d61c352842e56929 generated/opdef.h.inl 16654743e01160eeee879107cc4cac41 generated/opdef.h.inl
fd27534146a1cfcc791e40b2bb532076 generated/opdef.cpp.inl 97c541ed45b0be98f1ac2700f5b4d8a6 generated/opdef.cpp.inl
6754eaa59ef19178eba41e99e418790c generated/opdef.py.inl 6f9c6a7a1d71cca195c1e30743a1f542 generated/opdef.py.inl
df66a3089aa6c12e5b1d943cd3d20e80 generated/opdef.cpy.inl 806c5ceb34f571fc5c9d98d2ca8cad63 generated/opdef.cpy.inl
911001ef0dd771024919f7a1a3a009db generated/enum_macro.h 911001ef0dd771024919f7a1a3a009db generated/enum_macro.h
...@@ -7044,6 +7044,78 @@ OP_TRAIT_REG(Resize, Resize) ...@@ -7044,6 +7044,78 @@ OP_TRAIT_REG(Resize, Resize)
.props(Resize_props_impl) .props(Resize_props_impl)
.make_name(Resize_make_name_impl); .make_name(Resize_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Resize3D);
namespace {
size_t Resize3D_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Resize3D>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.imode));
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format));
val = mgb::hash_pair_combine(val, mgb::hash(op_.align_corners));
return val;
}
bool Resize3D_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<Resize3D>(),
&&b_ = rhs_.cast_final_safe<Resize3D>();
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.imode != b_.imode) return false;
if (a_.format != b_.format) return false;
if (a_.align_corners != b_.align_corners) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> Resize3D_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Resize3D>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
switch (op_.imode){
case Resize3D::InterpolationMode::NEAREST:
props_.emplace_back("imode", "NEAREST");
break;
case Resize3D::InterpolationMode::LINEAR:
props_.emplace_back("imode", "LINEAR");
break;
case Resize3D::InterpolationMode::AREA:
props_.emplace_back("imode", "AREA");
break;
case Resize3D::InterpolationMode::CUBIC:
props_.emplace_back("imode", "CUBIC");
break;
case Resize3D::InterpolationMode::LANCZOS4:
props_.emplace_back("imode", "LANCZOS4");
break;
default:
props_.emplace_back("imode", "INVALID");
break;
}
switch (op_.format){
case Resize3D::Format::NCDHW:
props_.emplace_back("format", "NCDHW");
break;
case Resize3D::Format::NDHWC:
props_.emplace_back("format", "NDHWC");
break;
default:
props_.emplace_back("format", "INVALID");
break;
}
props_.emplace_back("align_corners", std::to_string(op_.align_corners));
return props_;
}
std::string Resize3D_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Resize3D>();
static_cast<void>(op_);
return "Resize3D";
}
} // anonymous namespace
OP_TRAIT_REG(Resize3D, Resize3D)
.hash(Resize3D_hash_impl)
.is_same_st(Resize3D_is_same_st_impl)
.props(Resize3D_props_impl)
.make_name(Resize3D_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(SVD); MGB_DYN_TYPE_OBJ_FINAL_IMPL(SVD);
namespace { namespace {
......
...@@ -20536,6 +20536,169 @@ void _init_py_Resize(py::module m) { ...@@ -20536,6 +20536,169 @@ void _init_py_Resize(py::module m) {
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Resize::typeinfo(), &py_type).second); mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Resize::typeinfo(), &py_type).second);
} }
void _init_py_Resize3D_InterpolationMode(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<Resize3D::InterpolationMode>::type;
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "InterpolationMode", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
void _init_py_Resize3D_Format(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<Resize3D::Format>::type;
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "Format", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
PyOpDefBegin(Resize3D) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(Resize3D)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"imode", serialization<decltype(opdef.imode)>::dump(opdef.imode)},
{"format", serialization<decltype(opdef.format)>::dump(opdef.format)},
{"align_corners", serialization<decltype(opdef.align_corners)>::dump(opdef.align_corners)}
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(Resize3D)*>(self)->inst();
static_cast<void>(opdef);
{
auto&& iter = state.find("imode");
if (iter != state.end()) {
opdef.imode = serialization<decltype(opdef.imode)>::load(iter->second);
}
}
{
auto&& iter = state.find("format");
if (iter != state.end()) {
opdef.format = serialization<decltype(opdef.format)>::load(iter->second);
}
}
{
auto&& iter = state.find("align_corners");
if (iter != state.end()) {
opdef.align_corners = serialization<decltype(opdef.align_corners)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds);
static PyMethodDef py_init_methoddef;
// };
PyOpDefEnd(Resize3D)
int PyOp(Resize3D)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"imode", "format", "align_corners", "scope", NULL};
PyObject *imode = NULL, *format = NULL, *align_corners = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOO", const_cast<char**>(kwlist), &imode, &format, &align_corners, &scope))
return -1;
if (imode) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Resize3D)*>(self)->inst().imode =
py::cast<decltype(Resize3D::imode)>(py::handle(imode));
} CATCH_ALL(-1)
}
if (format) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Resize3D)*>(self)->inst().format =
py::cast<decltype(Resize3D::format)>(py::handle(format));
} CATCH_ALL(-1)
}
if (align_corners) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Resize3D)*>(self)->inst().align_corners =
py::cast<decltype(Resize3D::align_corners)>(py::handle(align_corners));
} CATCH_ALL(-1)
}
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}
return 0;
}
PyGetSetDef PyOp(Resize3D)::py_getsetters[] = {
{const_cast<char*>("imode"), py_get_generic(Resize3D, imode), py_set_generic(Resize3D, imode), const_cast<char*>("imode"), NULL},
{const_cast<char*>("format"), py_get_generic(Resize3D, format), py_set_generic(Resize3D, format), const_cast<char*>("format"), NULL},
{const_cast<char*>("align_corners"), py_get_generic(Resize3D, align_corners), py_set_generic(Resize3D, align_corners), const_cast<char*>("align_corners"), NULL},
{NULL} /* Sentinel */
};
PyMethodDef PyOp(Resize3D)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(Resize3D)::getstate, METH_NOARGS, "Resize3D getstate"},
{const_cast<char*>("__setstate__"), PyOp(Resize3D)::setstate, METH_VARARGS, "Resize3D setstate"},
{NULL} /* Sentinel */
};
PyObject *PyOp(Resize3D)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) {
if (PyOp(Resize3D)::py_init(self, args, kwds) < 0) {
return NULL;
}
Py_RETURN_NONE;
}
PyMethodDef PyOp(Resize3D)::py_init_methoddef = {
"__init__",
(PyCFunction)PyOp(Resize3D)::py_init_proxy,
METH_VARARGS | METH_KEYWORDS,
"__init__(self, imode: Union[str, InterpolationMode] = ..., format: Union[str, Format] = ..., align_corners: bool = ...) -> None\n"
};
void _init_py_Resize3D(py::module m) {
using py_op = PyOp(Resize3D);
auto& py_type = PyOpType(Resize3D);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.Resize3D";
py_type.tp_basicsize = sizeof(PyOp(Resize3D));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "Resize3D";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
py_type.tp_dict = PyDict_New();
PyObject* descr = PyDescr_NewMethod(&PyOpType(Resize3D), &PyOp(Resize3D)::py_init_methoddef);
PyDict_SetItemString(py_type.tp_dict, "__init__", descr);
mgb_assert(PyType_Ready(&py_type) >= 0);
_init_py_Resize3D_InterpolationMode(py_type);
_init_py_Resize3D_Format(py_type);
PyType_Modified(&py_type);
m.add_object("Resize3D", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Resize3D::typeinfo(), &py_type).second);
}
PyOpDefBegin(SVD) // { PyOpDefBegin(SVD) // {
static PyGetSetDef py_getsetters[]; static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[]; static PyMethodDef tp_methods[];
...@@ -23327,6 +23490,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) { ...@@ -23327,6 +23490,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
_init_py_RemoveAxis(m); \ _init_py_RemoveAxis(m); \
_init_py_Reshape(m); \ _init_py_Reshape(m); \
_init_py_Resize(m); \ _init_py_Resize(m); \
_init_py_Resize3D(m); \
_init_py_SVD(m); \ _init_py_SVD(m); \
_init_py_SetMeshIndexing(m); \ _init_py_SetMeshIndexing(m); \
_init_py_SetSubtensor(m); \ _init_py_SetSubtensor(m); \
......
...@@ -1808,6 +1808,23 @@ public: ...@@ -1808,6 +1808,23 @@ public:
} }
}; };
class Resize3D : public OpDefImplBase<Resize3D> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using InterpolationMode = ::megdnn::param::Resize3D::InterpolationMode;
using Format = ::megdnn::param::Resize3D::Format;
InterpolationMode imode = ::megdnn::param::Resize3D::InterpolationMode::LINEAR;
Format format = ::megdnn::param::Resize3D::Format::NDHWC;
bool align_corners = false;
Resize3D() = default;
Resize3D(InterpolationMode imode_, Format format_, bool align_corners_, std::string scope_ = {}): imode(imode_), format(format_), align_corners(align_corners_) { set_scope(scope_); }
Resize3D(::megdnn::param::Resize3D packed_param_0): imode(packed_param_0.imode), format(packed_param_0.format), align_corners(packed_param_0.align_corners) {}
::megdnn::param::Resize3D param() const {
return {imode, format, align_corners};
}
};
class SVD : public OpDefImplBase<SVD> { class SVD : public OpDefImplBase<SVD> {
MGB_DYN_TYPE_OBJ_FINAL_DECL; MGB_DYN_TYPE_OBJ_FINAL_DECL;
......
...@@ -1924,6 +1924,18 @@ ResizeInst ...@@ -1924,6 +1924,18 @@ ResizeInst
.def_readwrite("imode", &Resize::imode) .def_readwrite("imode", &Resize::imode)
.def_readwrite("format", &Resize::format); .def_readwrite("format", &Resize::format);
py::class_<Resize3D, std::shared_ptr<Resize3D>, OpDef> Resize3DInst(m, "Resize3D");
Resize3DInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode");
Resize3DInst.attr("Format") = Convolution3DInst.attr("Format");
Resize3DInst
.def(py::init<::megdnn::param::Resize3D::InterpolationMode, ::megdnn::param::Resize3D::Format, bool, std::string>(), py::arg("imode") = ::megdnn::param::Resize3D::InterpolationMode::LINEAR, py::arg("format") = ::megdnn::param::Resize3D::Format::NDHWC, py::arg("align_corners") = false, py::arg("scope") = {})
.def_readwrite("imode", &Resize3D::imode)
.def_readwrite("format", &Resize3D::format)
.def_readwrite("align_corners", &Resize3D::align_corners);
py::class_<SVD, std::shared_ptr<SVD>, OpDef> SVDInst(m, "SVD"); py::class_<SVD, std::shared_ptr<SVD>, OpDef> SVDInst(m, "SVD");
SVDInst SVDInst
......
...@@ -112,6 +112,8 @@ def Remap: MgbHashableOp<"Remap", [RemapParam]>; ...@@ -112,6 +112,8 @@ def Remap: MgbHashableOp<"Remap", [RemapParam]>;
def Resize: MgbHashableOp<"Resize", [ResizeParam]>; def Resize: MgbHashableOp<"Resize", [ResizeParam]>;
def Resize3D: MgbHashableOp<"Resize3D", [Resize3DParam]>;
def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]> { def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]> {
let extraArguments = (ins let extraArguments = (ins
MgbI32Attr:$ndim MgbI32Attr:$ndim
......
...@@ -502,6 +502,56 @@ MGB_IMPL_OPR_GRAD(ResizeForward) { ...@@ -502,6 +502,56 @@ MGB_IMPL_OPR_GRAD(ResizeForward) {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ResizeBackward); MGB_DYN_TYPE_OBJ_FINAL_IMPL(ResizeBackward);
MEGDNN_OPR_INIT2(ResizeBackward, "resize_bwd", 1, false); MEGDNN_OPR_INIT2(ResizeBackward, "resize_bwd", 1, false);
/* ======================= Resize3DForward ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Resize3DForward);
MEGDNN_OPR_INIT2(Resize3DForward, "resize3d")
void Resize3DForward::init_output_dtype() {
output(0)->dtype(input(0)->dtype());
outshape_by_symvar_enable(1, 1);
}
void Resize3DForward::add_input_layout_constraint() {
input(0)->add_layout_constraint_contiguous();
input(1)->add_layout_constraint_contiguous();
}
void Resize3DForward::outshape_by_symvar_do_get_output_shape(
TensorShape& dest, const ShapeInferInfo& shpinfo) {
TensorShape oshp3d;
cg::copy_tensor_value_to_shape(oshp3d, *shpinfo.shpval_inp_val.at(0));
auto imgshp = shpinfo.shape_inp_shp.at(0);
mgb_assert(
imgshp.ndim == 5 && oshp3d.ndim == 3,
"shape mismatch for Resize3DForward: img=%s out3d=%s",
imgshp.to_string().c_str(), oshp3d.to_string().c_str());
dest = imgshp;
for (int i = 0; i < 3; ++i) {
dest.shape[2 + i] = oshp3d.shape[i];
}
}
void Resize3DForward::init_output_static_infer_desc() {
Super::init_output_static_infer_desc();
init_output_static_infer_desc_workspace(false);
}
void Resize3DForward::scn_do_execute() {
intl::MegDNNOprMethInvoker<megdnn::Resize3D>::exec(megdnn_opr(), this);
}
size_t Resize3DForward::get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const {
return intl::MegDNNOprMethInvoker<megdnn::Resize3D>::get_workspace_in_bytes(
megdnn_opr(), this, input_shapes, output_shapes);
}
void Resize3DForward::record_execute_deps(ExecDependencyArray& deps) {
record_megdnn_opr(deps);
}
/* ======================= WarpAffineForward ======================= */ /* ======================= WarpAffineForward ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(WarpAffineForward); MGB_DYN_TYPE_OBJ_FINAL_IMPL(WarpAffineForward);
......
...@@ -195,6 +195,7 @@ MGB_SEREG_OPR(ResizeV2, 2); ...@@ -195,6 +195,7 @@ MGB_SEREG_OPR(ResizeV2, 2);
using DctChannelSelectV1 = opr::DctChannelSelect; using DctChannelSelectV1 = opr::DctChannelSelect;
MGB_SEREG_OPR(DctChannelSelectV1, 0); MGB_SEREG_OPR(DctChannelSelectV1, 0);
MGB_SEREG_OPR(Resize3D, 2);
} // namespace opr } // namespace opr
} // namespace mgb } // namespace mgb
......
...@@ -206,6 +206,43 @@ public: ...@@ -206,6 +206,43 @@ public:
const OperatorNodeConfig& config = {}); const OperatorNodeConfig& config = {});
}; };
/* ============================= user set shape =========================== */
MGB_DEFINE_OPR_CLASS(
Resize3DForward,
intl::WorkspaceSizeInfer<intl::OutshapeBySymvarSCNOpr<
mixin::MegDNNOprHolderImpl<megdnn::Resize3DForward>>>) // {
public:
Resize3DForward(
VarNode* in_tensor, VarNode* out_shape, const Param& param,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar in_tensor, SymbolVar out_shape, const Param& param = {},
const OperatorNodeConfig& config = {});
static SymbolVar make(
SymbolVar in_tensor, const TensorShape& out_shape, const Param& param = {},
const OperatorNodeConfig& config = {}) {
return make(
in_tensor, cg::var_from_tensor_shape(in_tensor, out_shape), param,
config);
}
private:
void init_output_dtype() override;
void add_input_layout_constraint() override;
void init_output_static_infer_desc() override;
void outshape_by_symvar_do_get_output_shape(
TensorShape& dest, const ShapeInferInfo& shpinfo) override;
void scn_do_execute() override;
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void record_execute_deps(ExecDependencyArray& deps) override;
};
using Resize3D = Resize3DForward;
MGB_DEFINE_OPR_CLASS( MGB_DEFINE_OPR_CLASS(
RemapForward, intl::MegDNNOprWrapperFwd<megdnn::RemapForward>) // { RemapForward, intl::MegDNNOprWrapperFwd<megdnn::RemapForward>) // {
public: public:
......
...@@ -768,6 +768,78 @@ TEST(TestOprImgproc, ResizeBackward) { ...@@ -768,6 +768,78 @@ TEST(TestOprImgproc, ResizeBackward) {
{{10, 8, 8, 4}, {10, 8, 4, 8}}, param, 1e-1, 1e-2); {{10, 8, 8, 4}, {10, 8, 4, 8}}, param, 1e-1, 1e-2);
} }
TEST(TestOprImgproc, Resize3DForward) {
using Param = opr::Resize3D::Param;
using IMode = Param::InterpolationMode;
using Format = Param::Format;
auto ac_param = Param{IMode::LINEAR, Format::NCDHW, true};
auto nac_param = Param{IMode::LINEAR, Format::NCDHW, false};
auto run = [&](TensorShape ishape, TensorShape oshape, std::vector<float> idata,
std::vector<float> oup_ref, Param param, DType test_dtype) {
std::shared_ptr<HostTensorND> inp_host(
new HostTensorND{CompNode::load("xpux"), ishape, test_dtype});
for (size_t i = 0; i < ishape.total_nr_elems(); ++i) {
if (test_dtype == dtype::Float32()) {
inp_host->ptr<dt_float32>()[i] = idata[i];
} else if (test_dtype == dtype::Float16()) {
inp_host->ptr<dt_float16>()[i] = idata[i];
} else {
mgb_assert(false, "invalid");
}
}
std::shared_ptr<HostTensorND> oup_shape_host(new HostTensorND{
CompNode::load("xpux"), TensorShape({oshape.ndim}), dtype::Int32()});
for (size_t i = 0; i < oshape.ndim; ++i) {
oup_shape_host->ptr<dt_int32>()[i] = oshape[i];
}
auto graph = ComputingGraph::make();
auto inp_sym = opr::Host2DeviceCopy::make(*graph, inp_host);
auto oup_shape_sym = opr::Host2DeviceCopy::make(*graph, oup_shape_host);
auto oup = opr::Resize3D::make(inp_sym, oup_shape_sym, param);
HostTensorND oup_host;
auto func = graph->compile({make_callback_copy(oup, oup_host)});
func->execute();
for (size_t i = 0; i < oshape.total_nr_elems(); ++i) {
if (test_dtype == dtype::Float32()) {
MGB_ASSERT_FLOAT_EQ(oup_ref[i], oup_host.ptr<dt_float32>()[i]);
} else if (test_dtype == dtype::Float16()) {
MGB_ASSERT_FLOAT_NEAR(oup_ref[i], oup_host.ptr<dt_float16>()[i], 1e-3);
} else {
mgb_assert(false, "invalid");
}
}
};
for (auto&& test_dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
run({1, 1, 2, 2, 2}, {4, 4, 4}, {0., 1., 2., 3., 4., 5., 6., 7.},
{0., 0.25, 0.75, 1., 0.5, 0.75, 1.25, 1.5, 1.5, 1.75, 2.25,
2.5, 2., 2.25, 2.75, 3., 1., 1.25, 1.75, 2., 1.5, 1.75,
2.25, 2.5, 2.5, 2.75, 3.25, 3.5, 3., 3.25, 3.75, 4., 3.,
3.25, 3.75, 4., 3.5, 3.75, 4.25, 4.5, 4.5, 4.75, 5.25, 5.5,
5., 5.25, 5.75, 6., 4., 4.25, 4.75, 5., 4.5, 4.75, 5.25,
5.5, 5.5, 5.75, 6.25, 6.5, 6., 6.25, 6.75, 7.},
nac_param, test_dtype);
run({1, 1, 2, 2, 2}, {4, 4, 4}, {0., 1., 2., 3., 4., 5., 6., 7.},
{0., 0.3333333, 0.6666667, 1., 0.6666667, 1.,
1.3333333, 1.6666666, 1.3333334, 1.6666667, 1.9999999, 2.3333333,
2., 2.3333333, 2.6666665, 3., 1.3333334, 1.6666666,
2.0000002, 2.3333335, 2., 2.333333, 2.6666667, 2.9999998,
2.6666665, 3., 3.3333333, 3.6666665, 3.3333333, 3.6666665,
4., 4.3333335, 2.6666667, 3., 3.3333337, 3.6666667,
3.3333335, 3.6666663, 4., 4.333333, 3.9999998, 4.333333,
4.6666665, 5., 4.6666665, 5., 5.3333335, 5.666667,
4., 4.333333, 4.666667, 5., 4.6666665, 4.9999995,
5.3333335, 5.6666665, 5.333333, 5.6666665, 6., 6.3333335,
6., 6.333333, 6.666667, 7.},
ac_param, test_dtype);
}
}
TEST(TestOprImgproc, WarpAffineForward) { TEST(TestOprImgproc, WarpAffineForward) {
constexpr size_t INP_H = 6, INP_W = 4, N = 2, C = 3; constexpr size_t INP_H = 6, INP_W = 4, N = 2, C = 3;
......
...@@ -127,6 +127,7 @@ union OperatorParam { ...@@ -127,6 +127,7 @@ union OperatorParam {
param.Fill = 93, param.Fill = 93,
param.GeneralNorm=94, param.GeneralNorm=94,
param.MultiHeadAttn=95, param.MultiHeadAttn=95,
param.Resize3D = 96,
} }
table Operator { table Operator {
......
...@@ -144,6 +144,7 @@ union OperatorParam { ...@@ -144,6 +144,7 @@ union OperatorParam {
param.Fill = 93, param.Fill = 93,
param.GeneralNorm=94, param.GeneralNorm=94,
param.MultiHeadAttn=95, param.MultiHeadAttn=95,
param.Resize3D = 96,
} }
table Operator { table Operator {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册