From 411253f85256b3a57252f60069caaea1c328741d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 11 Aug 2022 18:46:43 +0800 Subject: [PATCH] feat(mge): implement warp_affine backward GitOrigin-RevId: 7d7261cf6404a742c6af6c7a27ea416b9f5883d2 --- imperative/python/src/grad_override.cpp | 93 +++++ .../test/unit/functional/test_functional.py | 43 +++ imperative/src/impl/ops/warp_perspective.cpp | 38 ++ imperative/tablegen/generated/hash.txt | 10 +- imperative/tablegen/generated/opdef.cpp.inl | 296 +++++++++++++++ imperative/tablegen/generated/opdef.cpy.inl | 344 +++++++++++++++++- imperative/tablegen/generated/opdef.h.inl | 38 ++ imperative/tablegen/generated/opdef.py.inl | 30 ++ src/core/include/megbrain/ir/ops.td | 4 + 9 files changed, 890 insertions(+), 6 deletions(-) create mode 100644 imperative/src/impl/ops/warp_perspective.cpp diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp index c576aab25..8e9b5db12 100644 --- a/imperative/python/src/grad_override.cpp +++ b/imperative/python/src/grad_override.cpp @@ -587,6 +587,97 @@ std::optional fastpathcopy_grad_rule( return imperative::apply(op, inputs); } +std::optional warp_affine_grad_rule( + const OpDef& op, Span inputs, Span inputs_require_grad, + CustomBackward& backward) { + auto&& warp_affine = op.cast_final_safe(); + auto&& param = warp_affine.param(); + mgb_assert(inputs.size() == 3); + SmallVector inps; + if (inputs_require_grad[0] || inputs_require_grad[1]) { + for (size_t i = 0; i < inputs.size(); ++i) { + inps.push_back(inputs[i]); + } + } + auto maker = CustomGradMaker(backward, inputs.size()); + maker.output_size(1).output_captured(0, false); + maker.backward([inputs = std::move(inps), &warp_affine, + param](Span grads) { + mgb_assert(grads.size() == 1); + ValueRef grad = grads[0]; + SmallVector ret(2); + if (!grad) { + return ret; + } + + CompNodeValue::ref_t device = inputs[0].device(); + DTypeValue::ref_t dtype = inputs[0].dtype(); + HostTensorStorage storage(*device); + storage.ensure_size(3 * (dtype->size())); + + auto* ptr = reinterpret_cast(storage.ptr()); + ptr[0] = 0; + ptr[1] = 0; + ptr[2] = 1; + auto t = imperative::apply( + CreateTensor( + CreateTensor::Unique, *device, dtype::Float32(), + ValueShape({1, 1, 3})), + HostStorage::make(storage))[0]; + auto mat = inputs[1]; + auto&& concat = Concat::make(); + concat->axis = 1; + mat = imperative::apply(*concat, inputs[1], t)[0]; + if (inputs[0]) { + auto&& grad_op = WarpPerspectiveBackwardData::make( + param.imode, param.border_mode, param.format, param.border_val); + ValueRefList args_(inputs.size()); + args_[0] = mat; + args_[1] = grads[0]; + args_[2] = inputs[0]; + ret[0] = imperative::apply(*grad_op, args_)[0]; + } + if (inputs[1]) { + auto&& grad_op = WarpPerspectiveBackwardMat::make( + param.imode, param.border_mode, param.format, param.border_val); + ValueRefList args_(inputs.size()); + args_[0] = inputs[0]; + args_[1] = mat; + args_[2] = grads[0]; + ret[1] = imperative::apply(*grad_op, args_)[0]; + + std::vector> items; + items.push_back(std::make_tuple(1, true, true, false, false)); + auto&& subtensor = Subtensor::make(items); + + CompNodeValue::ref_t device = inputs[0].device(); + DTypeValue::ref_t dtype = inputs[0].dtype(); + int start_idx = 0; + int stop_idx = 2; + auto get_subtensor_index = [&](int idx) { + HostTensorStorage storage(*device); + storage.ensure_size(dtype::Int32().size()); + auto* ptr = reinterpret_cast(storage.ptr()); + ptr[0] = idx; + return imperative::apply( + CreateTensor( + CreateTensor::Unique, *device, dtype::Int32(), + ValueShape({1})), + HostStorage::make(storage))[0]; + }; + auto start = get_subtensor_index(start_idx); + auto stop = get_subtensor_index(stop_idx); + + auto data = ret[1]; + mgb_assert(data); + ret[1] = imperative::apply(*subtensor, data, start, stop)[0]; + } + return ret; + }); + maker.finalize(); + return imperative::apply(ApplyOp(op), inputs); +} + struct Init { Init() { CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule); @@ -610,6 +701,8 @@ struct Init { CustomBackward::register_grad_rule(MatrixMul::typeinfo(), matrix_mul_grad_rule); CustomBackward::register_grad_rule( BatchedMatrixMul::typeinfo(), batched_matrix_mul_grad_rule); + CustomBackward::register_grad_rule( + WarpAffine::typeinfo(), warp_affine_grad_rule); } } _; diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 87db37cb8..03428aaf0 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -14,6 +14,7 @@ import megengine.core.tensor.dtype as dtype import megengine.functional as F import megengine.jit as jit from megengine import Parameter, Tensor, is_cuda_available, tensor +from megengine.autodiff import GradManager from megengine.core._trace_option import use_symbolic_shape from megengine.core.autodiff.grad import Grad from megengine.core.tensor.utils import make_shape_tuple @@ -571,6 +572,48 @@ def test_warp_perspective(dt): np.testing.assert_equal(outp.numpy(), np.array([[[[5, 6], [9, 10]]]], dtype=dt)) +def test_warp_affine_grad(): + dy_np = np.arange(1, 10, dtype=np.float32).reshape(1, 1, 3, 3) + x_np = np.arange(1, 10, dtype=np.float32).reshape(1, 1, 3, 3) + + mat_np_affine = np.array([[[0.5, 0, 0], [0, 0.5, 0],]]).astype("float32") + mat_np_perspective = np.array([[[0.5, 0, 0], [0, 0.5, 0], [0, 0, 1]]]).astype( + "float32" + ) + + dmat_affine = Tensor(np.ones((1, 2, 3), dtype=np.float32)) + dy_affine = Tensor(dy_np) + x_affine = Tensor(x_np) + mat_affine = Tensor(mat_np_affine) + target_shape_affine = x_affine.shape[2:] + + dmat_perspective = Tensor(np.ones((1, 3, 3), dtype=np.float32)) + dy_perspective = Tensor(dy_np) + x_perspective = Tensor(x_np) + mat_perspective = Tensor(mat_np_perspective) + target_shape_perspective = x_perspective.shape[2:] + + gm = GradManager().attach([x_affine, mat_affine, x_perspective, mat_perspective]) + with gm: + y_affine = F.warp_affine( + x_affine, mat_affine, target_shape_affine, format="NCHW" + ) + y_perspective = F.warp_perspective( + x_perspective, mat_perspective, target_shape_perspective + ) + gm.backward([y_affine, y_perspective], [dy_affine, dy_perspective]) + + np.testing.assert_allclose( + x_affine.grad.numpy(), x_perspective.grad.numpy(), rtol=1e-5, atol=1e-5 + ) + np.testing.assert_allclose( + mat_affine.grad.numpy(), + mat_perspective.grad.numpy()[0:1, 0:2, 0:3], + rtol=1e-5, + atol=1e-5, + ) + + @pytest.mark.parametrize("dt", [np.float32, np.int8, np.uint8, np.float16]) def test_warp_perspective_mat_idx(dt): inp_shape = (2, 1, 4, 4) diff --git a/imperative/src/impl/ops/warp_perspective.cpp b/imperative/src/impl/ops/warp_perspective.cpp new file mode 100644 index 000000000..d812ddf28 --- /dev/null +++ b/imperative/src/impl/ops/warp_perspective.cpp @@ -0,0 +1,38 @@ +#include "../op_trait.h" +#include "megbrain/imperative/ops/autogen.h" + +#include "megbrain/opr/imgproc.h" + +namespace mgb::imperative { + +namespace { +namespace warp_perspective_backward_data { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + mgb_assert(inputs.size() == 3); + auto&& op = static_cast(def); + OperatorNodeConfig config{op.make_name()}; + return opr::WarpPerspectiveBackwardData::make( + inputs[0], inputs[1], inputs[2], op.param(), config); +} + +OP_TRAIT_REG(WarpPerspectiveBackwardData, WarpPerspectiveBackwardData) + .apply_on_var_node(apply_on_var_node) + .fallback(); +} // namespace warp_perspective_backward_data + +namespace warp_perspective_backward_mat { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + mgb_assert(inputs.size() == 3); + auto&& op = static_cast(def); + OperatorNodeConfig config{op.make_name()}; + return opr::WarpPerspectiveBackwardMat::make( + inputs[0], inputs[1], inputs[2], op.param(), config); +} + +OP_TRAIT_REG(WarpPerspectiveBackwardMat, WarpPerspectiveBackwardMat) + .apply_on_var_node(apply_on_var_node) + .fallback(); +} // namespace warp_perspective_backward_mat +} // namespace + +} // namespace mgb::imperative diff --git a/imperative/tablegen/generated/hash.txt b/imperative/tablegen/generated/hash.txt index d0a731196..d8d21eb7c 100644 --- a/imperative/tablegen/generated/hash.txt +++ b/imperative/tablegen/generated/hash.txt @@ -1,7 +1,7 @@ 905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py -759bfbf27fd3f0dd6b6edf06377e1d6b ../../src/core/include/megbrain/ir/ops.td -2a5851d0e2470d4d045811e7a20b1a3f generated/opdef.h.inl -55b862badeed19aed8e84c5d6f468ff2 generated/opdef.cpp.inl -f3f4c7f0ee1b39392df8a679f6d22596 generated/opdef.py.inl -6b11ca844a7855fdc5eebffaf563a89c generated/opdef.cpy.inl +e35e13523f43b7bea4034a0bf75937b7 ../../src/core/include/megbrain/ir/ops.td +240dccd6f8d42cadfd08c6ca90fe61b1 generated/opdef.h.inl +a79a4058ff18ffd9593ee5db3deef6c4 generated/opdef.cpp.inl +83c179ee7416824fbfab978a097cd4d3 generated/opdef.py.inl +86f70b1052331130f5e4c0ca53e68423 generated/opdef.cpy.inl 71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h diff --git a/imperative/tablegen/generated/opdef.cpp.inl b/imperative/tablegen/generated/opdef.cpp.inl index a4b4adf95..a8ef16aba 100644 --- a/imperative/tablegen/generated/opdef.cpp.inl +++ b/imperative/tablegen/generated/opdef.cpp.inl @@ -6941,4 +6941,300 @@ OP_TRAIT_REG(WarpPerspective, WarpPerspective) .props(WarpPerspective_props_impl) .make_name(WarpPerspective_make_name_impl); +MGB_DYN_TYPE_OBJ_FINAL_IMPL(WarpPerspectiveBackwardData); + +namespace { +size_t WarpPerspectiveBackwardData_hash_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + size_t val = mgb::hash(op_.dyn_typeinfo()); + val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.imode)); + val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.bmode)); + val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format)); + val = mgb::hash_pair_combine(val, mgb::hash(op_.border_val)); + return val; +} +bool WarpPerspectiveBackwardData_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { + auto &&a_ = lhs_.cast_final_safe(), + &&b_ = rhs_.cast_final_safe(); + static_cast(a_); + static_cast(b_); + if (a_.imode != b_.imode) return false; + if (a_.bmode != b_.bmode) return false; + if (a_.format != b_.format) return false; + if (a_.border_val != b_.border_val) return false; + return true; +} +std::vector> WarpPerspectiveBackwardData_props_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + std::vector> props_; + switch (op_.imode){ + case WarpPerspectiveBackwardData::InterpolationMode::NEAREST: + props_.emplace_back("imode", "NEAREST"); + break; + case WarpPerspectiveBackwardData::InterpolationMode::LINEAR: + props_.emplace_back("imode", "LINEAR"); + break; + case WarpPerspectiveBackwardData::InterpolationMode::AREA: + props_.emplace_back("imode", "AREA"); + break; + case WarpPerspectiveBackwardData::InterpolationMode::CUBIC: + props_.emplace_back("imode", "CUBIC"); + break; + case WarpPerspectiveBackwardData::InterpolationMode::LANCZOS4: + props_.emplace_back("imode", "LANCZOS4"); + break; + default: + props_.emplace_back("imode", "INVALID"); + break; + } + switch (op_.bmode){ + case WarpPerspectiveBackwardData::BorderMode::REPLICATE: + props_.emplace_back("bmode", "REPLICATE"); + break; + case WarpPerspectiveBackwardData::BorderMode::REFLECT: + props_.emplace_back("bmode", "REFLECT"); + break; + case WarpPerspectiveBackwardData::BorderMode::REFLECT_101: + props_.emplace_back("bmode", "REFLECT_101"); + break; + case WarpPerspectiveBackwardData::BorderMode::WRAP: + props_.emplace_back("bmode", "WRAP"); + break; + case WarpPerspectiveBackwardData::BorderMode::CONSTANT: + props_.emplace_back("bmode", "CONSTANT"); + break; + case WarpPerspectiveBackwardData::BorderMode::TRANSPARENT: + props_.emplace_back("bmode", "TRANSPARENT"); + break; + case WarpPerspectiveBackwardData::BorderMode::ISOLATED: + props_.emplace_back("bmode", "ISOLATED"); + break; + default: + props_.emplace_back("bmode", "INVALID"); + break; + } + switch (op_.format){ + case WarpPerspectiveBackwardData::Format::NCHW: + props_.emplace_back("format", "NCHW"); + break; + case WarpPerspectiveBackwardData::Format::NHWC: + props_.emplace_back("format", "NHWC"); + break; + case WarpPerspectiveBackwardData::Format::NHWCD4: + props_.emplace_back("format", "NHWCD4"); + break; + case WarpPerspectiveBackwardData::Format::NCHW4: + props_.emplace_back("format", "NCHW4"); + break; + case WarpPerspectiveBackwardData::Format::NCHW8: + props_.emplace_back("format", "NCHW8"); + break; + case WarpPerspectiveBackwardData::Format::NCHW32: + props_.emplace_back("format", "NCHW32"); + break; + case WarpPerspectiveBackwardData::Format::NCHW88: + props_.emplace_back("format", "NCHW88"); + break; + case WarpPerspectiveBackwardData::Format::NCHW44: + props_.emplace_back("format", "NCHW44"); + break; + case WarpPerspectiveBackwardData::Format::NCHW44_DOT: + props_.emplace_back("format", "NCHW44_DOT"); + break; + case WarpPerspectiveBackwardData::Format::NCHW4_NCHW32: + props_.emplace_back("format", "NCHW4_NCHW32"); + break; + case WarpPerspectiveBackwardData::Format::NCHW32_NCHW4: + props_.emplace_back("format", "NCHW32_NCHW4"); + break; + case WarpPerspectiveBackwardData::Format::NCHW4_NCHW: + props_.emplace_back("format", "NCHW4_NCHW"); + break; + case WarpPerspectiveBackwardData::Format::NHWC_NCHW: + props_.emplace_back("format", "NHWC_NCHW"); + break; + case WarpPerspectiveBackwardData::Format::NHWC_NCHW4_IC_SMALL: + props_.emplace_back("format", "NHWC_NCHW4_IC_SMALL"); + break; + case WarpPerspectiveBackwardData::Format::NCHW_NCHW4_IC_SMALL: + props_.emplace_back("format", "NCHW_NCHW4_IC_SMALL"); + break; + case WarpPerspectiveBackwardData::Format::CHWN4: + props_.emplace_back("format", "CHWN4"); + break; + case WarpPerspectiveBackwardData::Format::NCHW64: + props_.emplace_back("format", "NCHW64"); + break; + case WarpPerspectiveBackwardData::Format::NCHW4_NHWC: + props_.emplace_back("format", "NCHW4_NHWC"); + break; + default: + props_.emplace_back("format", "INVALID"); + break; + } + props_.emplace_back("border_val", std::to_string(op_.border_val)); + return props_; +} +std::string WarpPerspectiveBackwardData_make_name_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + return "WarpPerspectiveBackwardData"; +} +} // anonymous namespace +OP_TRAIT_REG(WarpPerspectiveBackwardData, WarpPerspectiveBackwardData) + .hash(WarpPerspectiveBackwardData_hash_impl) + .is_same_st(WarpPerspectiveBackwardData_is_same_st_impl) + .props(WarpPerspectiveBackwardData_props_impl) + .make_name(WarpPerspectiveBackwardData_make_name_impl); + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(WarpPerspectiveBackwardMat); + +namespace { +size_t WarpPerspectiveBackwardMat_hash_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + size_t val = mgb::hash(op_.dyn_typeinfo()); + val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.imode)); + val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.bmode)); + val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format)); + val = mgb::hash_pair_combine(val, mgb::hash(op_.border_val)); + return val; +} +bool WarpPerspectiveBackwardMat_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { + auto &&a_ = lhs_.cast_final_safe(), + &&b_ = rhs_.cast_final_safe(); + static_cast(a_); + static_cast(b_); + if (a_.imode != b_.imode) return false; + if (a_.bmode != b_.bmode) return false; + if (a_.format != b_.format) return false; + if (a_.border_val != b_.border_val) return false; + return true; +} +std::vector> WarpPerspectiveBackwardMat_props_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + std::vector> props_; + switch (op_.imode){ + case WarpPerspectiveBackwardMat::InterpolationMode::NEAREST: + props_.emplace_back("imode", "NEAREST"); + break; + case WarpPerspectiveBackwardMat::InterpolationMode::LINEAR: + props_.emplace_back("imode", "LINEAR"); + break; + case WarpPerspectiveBackwardMat::InterpolationMode::AREA: + props_.emplace_back("imode", "AREA"); + break; + case WarpPerspectiveBackwardMat::InterpolationMode::CUBIC: + props_.emplace_back("imode", "CUBIC"); + break; + case WarpPerspectiveBackwardMat::InterpolationMode::LANCZOS4: + props_.emplace_back("imode", "LANCZOS4"); + break; + default: + props_.emplace_back("imode", "INVALID"); + break; + } + switch (op_.bmode){ + case WarpPerspectiveBackwardMat::BorderMode::REPLICATE: + props_.emplace_back("bmode", "REPLICATE"); + break; + case WarpPerspectiveBackwardMat::BorderMode::REFLECT: + props_.emplace_back("bmode", "REFLECT"); + break; + case WarpPerspectiveBackwardMat::BorderMode::REFLECT_101: + props_.emplace_back("bmode", "REFLECT_101"); + break; + case WarpPerspectiveBackwardMat::BorderMode::WRAP: + props_.emplace_back("bmode", "WRAP"); + break; + case WarpPerspectiveBackwardMat::BorderMode::CONSTANT: + props_.emplace_back("bmode", "CONSTANT"); + break; + case WarpPerspectiveBackwardMat::BorderMode::TRANSPARENT: + props_.emplace_back("bmode", "TRANSPARENT"); + break; + case WarpPerspectiveBackwardMat::BorderMode::ISOLATED: + props_.emplace_back("bmode", "ISOLATED"); + break; + default: + props_.emplace_back("bmode", "INVALID"); + break; + } + switch (op_.format){ + case WarpPerspectiveBackwardMat::Format::NCHW: + props_.emplace_back("format", "NCHW"); + break; + case WarpPerspectiveBackwardMat::Format::NHWC: + props_.emplace_back("format", "NHWC"); + break; + case WarpPerspectiveBackwardMat::Format::NHWCD4: + props_.emplace_back("format", "NHWCD4"); + break; + case WarpPerspectiveBackwardMat::Format::NCHW4: + props_.emplace_back("format", "NCHW4"); + break; + case WarpPerspectiveBackwardMat::Format::NCHW8: + props_.emplace_back("format", "NCHW8"); + break; + case WarpPerspectiveBackwardMat::Format::NCHW32: + props_.emplace_back("format", "NCHW32"); + break; + case WarpPerspectiveBackwardMat::Format::NCHW88: + props_.emplace_back("format", "NCHW88"); + break; + case WarpPerspectiveBackwardMat::Format::NCHW44: + props_.emplace_back("format", "NCHW44"); + break; + case WarpPerspectiveBackwardMat::Format::NCHW44_DOT: + props_.emplace_back("format", "NCHW44_DOT"); + break; + case WarpPerspectiveBackwardMat::Format::NCHW4_NCHW32: + props_.emplace_back("format", "NCHW4_NCHW32"); + break; + case WarpPerspectiveBackwardMat::Format::NCHW32_NCHW4: + props_.emplace_back("format", "NCHW32_NCHW4"); + break; + case WarpPerspectiveBackwardMat::Format::NCHW4_NCHW: + props_.emplace_back("format", "NCHW4_NCHW"); + break; + case WarpPerspectiveBackwardMat::Format::NHWC_NCHW: + props_.emplace_back("format", "NHWC_NCHW"); + break; + case WarpPerspectiveBackwardMat::Format::NHWC_NCHW4_IC_SMALL: + props_.emplace_back("format", "NHWC_NCHW4_IC_SMALL"); + break; + case WarpPerspectiveBackwardMat::Format::NCHW_NCHW4_IC_SMALL: + props_.emplace_back("format", "NCHW_NCHW4_IC_SMALL"); + break; + case WarpPerspectiveBackwardMat::Format::CHWN4: + props_.emplace_back("format", "CHWN4"); + break; + case WarpPerspectiveBackwardMat::Format::NCHW64: + props_.emplace_back("format", "NCHW64"); + break; + case WarpPerspectiveBackwardMat::Format::NCHW4_NHWC: + props_.emplace_back("format", "NCHW4_NHWC"); + break; + default: + props_.emplace_back("format", "INVALID"); + break; + } + props_.emplace_back("border_val", std::to_string(op_.border_val)); + return props_; +} +std::string WarpPerspectiveBackwardMat_make_name_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + return "WarpPerspectiveBackwardMat"; +} +} // anonymous namespace +OP_TRAIT_REG(WarpPerspectiveBackwardMat, WarpPerspectiveBackwardMat) + .hash(WarpPerspectiveBackwardMat_hash_impl) + .is_same_st(WarpPerspectiveBackwardMat_is_same_st_impl) + .props(WarpPerspectiveBackwardMat_props_impl) + .make_name(WarpPerspectiveBackwardMat_make_name_impl); + // clang-format on diff --git a/imperative/tablegen/generated/opdef.cpy.inl b/imperative/tablegen/generated/opdef.cpy.inl index 687043322..fa1fa9fce 100644 --- a/imperative/tablegen/generated/opdef.cpy.inl +++ b/imperative/tablegen/generated/opdef.cpy.inl @@ -18185,6 +18185,346 @@ void _init_py_WarpPerspective(py::module m) { m.add_object("WarpPerspective", reinterpret_cast(&py_type)); mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(WarpPerspective::typeinfo(), &py_type).second); } + +void _init_py_WarpPerspectiveBackwardData_InterpolationMode(PyTypeObject& py_type) { + auto& e_type = EnumWrapper::type; + + Py_INCREF(e_type); + mgb_assert(PyDict_SetItemString( + py_type.tp_dict, "InterpolationMode", reinterpret_cast(e_type)) >= 0); +} + +void _init_py_WarpPerspectiveBackwardData_BorderMode(PyTypeObject& py_type) { + auto& e_type = EnumWrapper::type; + + Py_INCREF(e_type); + mgb_assert(PyDict_SetItemString( + py_type.tp_dict, "BorderMode", reinterpret_cast(e_type)) >= 0); +} + +void _init_py_WarpPerspectiveBackwardData_Format(PyTypeObject& py_type) { + auto& e_type = EnumWrapper::type; + + Py_INCREF(e_type); + mgb_assert(PyDict_SetItemString( + py_type.tp_dict, "Format", reinterpret_cast(e_type)) >= 0); +} + +PyOpDefBegin(WarpPerspectiveBackwardData) // { + static PyGetSetDef py_getsetters[]; + static PyMethodDef tp_methods[]; + + static PyObject* getstate(PyObject* self, PyObject*) { + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + std::unordered_map state { + + {"imode", serialization::dump(opdef.imode)}, + {"bmode", serialization::dump(opdef.bmode)}, + {"format", serialization::dump(opdef.format)}, + {"border_val", serialization::dump(opdef.border_val)} + }; + return py::cast(state).release().ptr(); + } + static PyObject* setstate(PyObject* self, PyObject* args) { + PyObject* dict = PyTuple_GetItem(args, 0); + if (!dict) return NULL; + auto state = py::cast>(dict); + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + + { + auto&& iter = state.find("imode"); + if (iter != state.end()) { + opdef.imode = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("bmode"); + if (iter != state.end()) { + opdef.bmode = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("format"); + if (iter != state.end()) { + opdef.format = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("border_val"); + if (iter != state.end()) { + opdef.border_val = serialization::load(iter->second); + } + } + Py_RETURN_NONE; + } + static int py_init(PyObject *self, PyObject *args, PyObject *kwds); +// }; +PyOpDefEnd(WarpPerspectiveBackwardData) + +int PyOp(WarpPerspectiveBackwardData)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { + static const char* kwlist[] = {"imode", "bmode", "format", "border_val", "scope", NULL}; + PyObject *imode = NULL, *bmode = NULL, *format = NULL, *border_val = NULL, *scope = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOO", const_cast(kwlist), &imode, &bmode, &format, &border_val, &scope)) + return -1; + + if (imode) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().imode = + py::cast(py::handle(imode)); + } CATCH_ALL(-1) + } + + if (bmode) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().bmode = + py::cast(py::handle(bmode)); + } CATCH_ALL(-1) + } + + if (format) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().format = + py::cast(py::handle(format)); + } CATCH_ALL(-1) + } + + if (border_val) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().border_val = + py::cast(py::handle(border_val)); + } CATCH_ALL(-1) + } + + if (scope) { + try { + reinterpret_cast(self)->op + ->set_scope(py::cast(py::handle(scope))); + } CATCH_ALL(-1) + } + + return 0; +} + +PyGetSetDef PyOp(WarpPerspectiveBackwardData)::py_getsetters[] = { + {const_cast("imode"), py_get_generic(WarpPerspectiveBackwardData, imode), py_set_generic(WarpPerspectiveBackwardData, imode), const_cast("imode"), NULL}, + {const_cast("bmode"), py_get_generic(WarpPerspectiveBackwardData, bmode), py_set_generic(WarpPerspectiveBackwardData, bmode), const_cast("bmode"), NULL}, + {const_cast("format"), py_get_generic(WarpPerspectiveBackwardData, format), py_set_generic(WarpPerspectiveBackwardData, format), const_cast("format"), NULL}, + {const_cast("border_val"), py_get_generic(WarpPerspectiveBackwardData, border_val), py_set_generic(WarpPerspectiveBackwardData, border_val), const_cast("border_val"), NULL}, + {NULL} /* Sentinel */ +}; + + PyMethodDef PyOp(WarpPerspectiveBackwardData)::tp_methods[] = { + {const_cast("__getstate__"), PyOp(WarpPerspectiveBackwardData)::getstate, METH_NOARGS, "WarpPerspectiveBackwardData getstate"}, + {const_cast("__setstate__"), PyOp(WarpPerspectiveBackwardData)::setstate, METH_VARARGS, "WarpPerspectiveBackwardData setstate"}, + {NULL} /* Sentinel */ + }; + +void _init_py_WarpPerspectiveBackwardData(py::module m) { + using py_op = PyOp(WarpPerspectiveBackwardData); + auto& py_type = PyOpType(WarpPerspectiveBackwardData); + py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; + py_type.tp_name = "megengine.core._imperative_rt.ops.WarpPerspectiveBackwardData"; + py_type.tp_basicsize = sizeof(PyOp(WarpPerspectiveBackwardData)); + py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + py_type.tp_doc = "WarpPerspectiveBackwardData"; + py_type.tp_base = &PyOpType(OpDef); + py_type.tp_dealloc = py_dealloc_generic; + py_type.tp_new = py_new_generic; + py_type.tp_init = py_op::py_init; + py_type.tp_methods = py_op::tp_methods; + py_type.tp_getset = py_op::py_getsetters; + mgb_assert(PyType_Ready(&py_type) >= 0); + _init_py_WarpPerspectiveBackwardData_InterpolationMode(py_type); + _init_py_WarpPerspectiveBackwardData_BorderMode(py_type); + _init_py_WarpPerspectiveBackwardData_Format(py_type); + + PyType_Modified(&py_type); + m.add_object("WarpPerspectiveBackwardData", reinterpret_cast(&py_type)); + mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(WarpPerspectiveBackwardData::typeinfo(), &py_type).second); +} + +void _init_py_WarpPerspectiveBackwardMat_InterpolationMode(PyTypeObject& py_type) { + auto& e_type = EnumWrapper::type; + + Py_INCREF(e_type); + mgb_assert(PyDict_SetItemString( + py_type.tp_dict, "InterpolationMode", reinterpret_cast(e_type)) >= 0); +} + +void _init_py_WarpPerspectiveBackwardMat_BorderMode(PyTypeObject& py_type) { + auto& e_type = EnumWrapper::type; + + Py_INCREF(e_type); + mgb_assert(PyDict_SetItemString( + py_type.tp_dict, "BorderMode", reinterpret_cast(e_type)) >= 0); +} + +void _init_py_WarpPerspectiveBackwardMat_Format(PyTypeObject& py_type) { + auto& e_type = EnumWrapper::type; + + Py_INCREF(e_type); + mgb_assert(PyDict_SetItemString( + py_type.tp_dict, "Format", reinterpret_cast(e_type)) >= 0); +} + +PyOpDefBegin(WarpPerspectiveBackwardMat) // { + static PyGetSetDef py_getsetters[]; + static PyMethodDef tp_methods[]; + + static PyObject* getstate(PyObject* self, PyObject*) { + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + std::unordered_map state { + + {"imode", serialization::dump(opdef.imode)}, + {"bmode", serialization::dump(opdef.bmode)}, + {"format", serialization::dump(opdef.format)}, + {"border_val", serialization::dump(opdef.border_val)} + }; + return py::cast(state).release().ptr(); + } + static PyObject* setstate(PyObject* self, PyObject* args) { + PyObject* dict = PyTuple_GetItem(args, 0); + if (!dict) return NULL; + auto state = py::cast>(dict); + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + + { + auto&& iter = state.find("imode"); + if (iter != state.end()) { + opdef.imode = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("bmode"); + if (iter != state.end()) { + opdef.bmode = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("format"); + if (iter != state.end()) { + opdef.format = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("border_val"); + if (iter != state.end()) { + opdef.border_val = serialization::load(iter->second); + } + } + Py_RETURN_NONE; + } + static int py_init(PyObject *self, PyObject *args, PyObject *kwds); +// }; +PyOpDefEnd(WarpPerspectiveBackwardMat) + +int PyOp(WarpPerspectiveBackwardMat)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { + static const char* kwlist[] = {"imode", "bmode", "format", "border_val", "scope", NULL}; + PyObject *imode = NULL, *bmode = NULL, *format = NULL, *border_val = NULL, *scope = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOO", const_cast(kwlist), &imode, &bmode, &format, &border_val, &scope)) + return -1; + + if (imode) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().imode = + py::cast(py::handle(imode)); + } CATCH_ALL(-1) + } + + if (bmode) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().bmode = + py::cast(py::handle(bmode)); + } CATCH_ALL(-1) + } + + if (format) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().format = + py::cast(py::handle(format)); + } CATCH_ALL(-1) + } + + if (border_val) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().border_val = + py::cast(py::handle(border_val)); + } CATCH_ALL(-1) + } + + if (scope) { + try { + reinterpret_cast(self)->op + ->set_scope(py::cast(py::handle(scope))); + } CATCH_ALL(-1) + } + + return 0; +} + +PyGetSetDef PyOp(WarpPerspectiveBackwardMat)::py_getsetters[] = { + {const_cast("imode"), py_get_generic(WarpPerspectiveBackwardMat, imode), py_set_generic(WarpPerspectiveBackwardMat, imode), const_cast("imode"), NULL}, + {const_cast("bmode"), py_get_generic(WarpPerspectiveBackwardMat, bmode), py_set_generic(WarpPerspectiveBackwardMat, bmode), const_cast("bmode"), NULL}, + {const_cast("format"), py_get_generic(WarpPerspectiveBackwardMat, format), py_set_generic(WarpPerspectiveBackwardMat, format), const_cast("format"), NULL}, + {const_cast("border_val"), py_get_generic(WarpPerspectiveBackwardMat, border_val), py_set_generic(WarpPerspectiveBackwardMat, border_val), const_cast("border_val"), NULL}, + {NULL} /* Sentinel */ +}; + + PyMethodDef PyOp(WarpPerspectiveBackwardMat)::tp_methods[] = { + {const_cast("__getstate__"), PyOp(WarpPerspectiveBackwardMat)::getstate, METH_NOARGS, "WarpPerspectiveBackwardMat getstate"}, + {const_cast("__setstate__"), PyOp(WarpPerspectiveBackwardMat)::setstate, METH_VARARGS, "WarpPerspectiveBackwardMat setstate"}, + {NULL} /* Sentinel */ + }; + +void _init_py_WarpPerspectiveBackwardMat(py::module m) { + using py_op = PyOp(WarpPerspectiveBackwardMat); + auto& py_type = PyOpType(WarpPerspectiveBackwardMat); + py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; + py_type.tp_name = "megengine.core._imperative_rt.ops.WarpPerspectiveBackwardMat"; + py_type.tp_basicsize = sizeof(PyOp(WarpPerspectiveBackwardMat)); + py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + py_type.tp_doc = "WarpPerspectiveBackwardMat"; + py_type.tp_base = &PyOpType(OpDef); + py_type.tp_dealloc = py_dealloc_generic; + py_type.tp_new = py_new_generic; + py_type.tp_init = py_op::py_init; + py_type.tp_methods = py_op::tp_methods; + py_type.tp_getset = py_op::py_getsetters; + mgb_assert(PyType_Ready(&py_type) >= 0); + _init_py_WarpPerspectiveBackwardMat_InterpolationMode(py_type); + _init_py_WarpPerspectiveBackwardMat_BorderMode(py_type); + _init_py_WarpPerspectiveBackwardMat_Format(py_type); + + PyType_Modified(&py_type); + m.add_object("WarpPerspectiveBackwardMat", reinterpret_cast(&py_type)); + mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(WarpPerspectiveBackwardMat::typeinfo(), &py_type).second); +} #define INIT_ALL_OP(m) \ _init_py_AdaptivePooling(m); \ _init_py_AddAxis(m); \ @@ -18290,5 +18630,7 @@ void _init_py_WarpPerspective(py::module m) { _init_py_TypeCvt(m); \ _init_py_UniformRNG(m); \ _init_py_WarpAffine(m); \ - _init_py_WarpPerspective(m); + _init_py_WarpPerspective(m); \ + _init_py_WarpPerspectiveBackwardData(m); \ + _init_py_WarpPerspectiveBackwardMat(m); // clang-format on diff --git a/imperative/tablegen/generated/opdef.h.inl b/imperative/tablegen/generated/opdef.h.inl index cd51a5724..47d78f116 100644 --- a/imperative/tablegen/generated/opdef.h.inl +++ b/imperative/tablegen/generated/opdef.h.inl @@ -1795,4 +1795,42 @@ public: } }; +class WarpPerspectiveBackwardData : public OpDefImplBase { + MGB_DYN_TYPE_OBJ_FINAL_DECL; + +public: + using InterpolationMode = ::megdnn::param::WarpPerspective::InterpolationMode; + using BorderMode = ::megdnn::param::WarpPerspective::BorderMode; + using Format = ::megdnn::param::WarpPerspective::Format; + InterpolationMode imode = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR; + BorderMode bmode = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE; + Format format = ::megdnn::param::WarpPerspective::Format::NCHW; + float border_val = .0f; + WarpPerspectiveBackwardData() = default; + WarpPerspectiveBackwardData(InterpolationMode imode_, BorderMode bmode_, Format format_, float border_val_, std::string scope_ = {}): imode(imode_), bmode(bmode_), format(format_), border_val(border_val_) { set_scope(scope_); } + WarpPerspectiveBackwardData(::megdnn::param::WarpPerspective packed_param_0): imode(packed_param_0.imode), bmode(packed_param_0.bmode), format(packed_param_0.format), border_val(packed_param_0.border_val) {} + ::megdnn::param::WarpPerspective param() const { + return {imode, bmode, format, border_val}; + } +}; + +class WarpPerspectiveBackwardMat : public OpDefImplBase { + MGB_DYN_TYPE_OBJ_FINAL_DECL; + +public: + using InterpolationMode = ::megdnn::param::WarpPerspective::InterpolationMode; + using BorderMode = ::megdnn::param::WarpPerspective::BorderMode; + using Format = ::megdnn::param::WarpPerspective::Format; + InterpolationMode imode = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR; + BorderMode bmode = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE; + Format format = ::megdnn::param::WarpPerspective::Format::NCHW; + float border_val = .0f; + WarpPerspectiveBackwardMat() = default; + WarpPerspectiveBackwardMat(InterpolationMode imode_, BorderMode bmode_, Format format_, float border_val_, std::string scope_ = {}): imode(imode_), bmode(bmode_), format(format_), border_val(border_val_) { set_scope(scope_); } + WarpPerspectiveBackwardMat(::megdnn::param::WarpPerspective packed_param_0): imode(packed_param_0.imode), bmode(packed_param_0.bmode), format(packed_param_0.format), border_val(packed_param_0.border_val) {} + ::megdnn::param::WarpPerspective param() const { + return {imode, bmode, format, border_val}; + } +}; + // clang-format on diff --git a/imperative/tablegen/generated/opdef.py.inl b/imperative/tablegen/generated/opdef.py.inl index 631639c75..696a6e0da 100644 --- a/imperative/tablegen/generated/opdef.py.inl +++ b/imperative/tablegen/generated/opdef.py.inl @@ -1858,4 +1858,34 @@ WarpPerspectiveInst .def_readwrite("format", &WarpPerspective::format) .def_readwrite("border_val", &WarpPerspective::border_val); +py::class_, OpDef> WarpPerspectiveBackwardDataInst(m, "WarpPerspectiveBackwardData"); + +WarpPerspectiveBackwardDataInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode"); + +WarpPerspectiveBackwardDataInst.attr("BorderMode") = RemapInst.attr("BorderMode"); + +WarpPerspectiveBackwardDataInst.attr("Format") = AdaptivePoolingInst.attr("Format"); + +WarpPerspectiveBackwardDataInst + .def(py::init<::megdnn::param::WarpPerspective::InterpolationMode, ::megdnn::param::WarpPerspective::BorderMode, ::megdnn::param::WarpPerspective::Format, float, std::string>(), py::arg("imode") = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR, py::arg("bmode") = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE, py::arg("format") = ::megdnn::param::WarpPerspective::Format::NCHW, py::arg("border_val") = .0f, py::arg("scope") = {}) + .def_readwrite("imode", &WarpPerspectiveBackwardData::imode) + .def_readwrite("bmode", &WarpPerspectiveBackwardData::bmode) + .def_readwrite("format", &WarpPerspectiveBackwardData::format) + .def_readwrite("border_val", &WarpPerspectiveBackwardData::border_val); + +py::class_, OpDef> WarpPerspectiveBackwardMatInst(m, "WarpPerspectiveBackwardMat"); + +WarpPerspectiveBackwardMatInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode"); + +WarpPerspectiveBackwardMatInst.attr("BorderMode") = RemapInst.attr("BorderMode"); + +WarpPerspectiveBackwardMatInst.attr("Format") = AdaptivePoolingInst.attr("Format"); + +WarpPerspectiveBackwardMatInst + .def(py::init<::megdnn::param::WarpPerspective::InterpolationMode, ::megdnn::param::WarpPerspective::BorderMode, ::megdnn::param::WarpPerspective::Format, float, std::string>(), py::arg("imode") = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR, py::arg("bmode") = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE, py::arg("format") = ::megdnn::param::WarpPerspective::Format::NCHW, py::arg("border_val") = .0f, py::arg("scope") = {}) + .def_readwrite("imode", &WarpPerspectiveBackwardMat::imode) + .def_readwrite("bmode", &WarpPerspectiveBackwardMat::bmode) + .def_readwrite("format", &WarpPerspectiveBackwardMat::format) + .def_readwrite("border_val", &WarpPerspectiveBackwardMat::border_val); + // clang-format on diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index d70cd20b8..a7206006c 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -104,6 +104,10 @@ def WarpPerspective: MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>; def WarpAffine: MgbHashableOp<"WarpAffine", [WarpAffineParam]>; +def WarpPerspectiveBackwardData: MgbHashableOp<"WarpPerspectiveBackwardData", [WarpPerspectiveParam]>; + +def WarpPerspectiveBackwardMat: MgbHashableOp<"WarpPerspectiveBackwardMat", [WarpPerspectiveParam]>; + def Remap: MgbHashableOp<"Remap", [RemapParam]>; def Resize: MgbHashableOp<"Resize", [ResizeParam]>; -- GitLab