提交 411253f8 编写于 作者: M Megvii Engine Team

feat(mge): implement warp_affine backward

GitOrigin-RevId: 7d7261cf6404a742c6af6c7a27ea416b9f5883d2
上级 496070cf
......@@ -587,6 +587,97 @@ std::optional<ValueRefList> fastpathcopy_grad_rule(
return imperative::apply(op, inputs);
}
std::optional<ValueRefList> warp_affine_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& warp_affine = op.cast_final_safe<WarpAffine>();
auto&& param = warp_affine.param();
mgb_assert(inputs.size() == 3);
SmallVector<ValueRef> inps;
if (inputs_require_grad[0] || inputs_require_grad[1]) {
for (size_t i = 0; i < inputs.size(); ++i) {
inps.push_back(inputs[i]);
}
}
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([inputs = std::move(inps), &warp_affine,
param](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
SmallVector<ValueRef> ret(2);
if (!grad) {
return ret;
}
CompNodeValue::ref_t device = inputs[0].device();
DTypeValue::ref_t dtype = inputs[0].dtype();
HostTensorStorage storage(*device);
storage.ensure_size(3 * (dtype->size()));
auto* ptr = reinterpret_cast<dt_float32*>(storage.ptr());
ptr[0] = 0;
ptr[1] = 0;
ptr[2] = 1;
auto t = imperative::apply(
CreateTensor(
CreateTensor::Unique, *device, dtype::Float32(),
ValueShape({1, 1, 3})),
HostStorage::make(storage))[0];
auto mat = inputs[1];
auto&& concat = Concat::make();
concat->axis = 1;
mat = imperative::apply(*concat, inputs[1], t)[0];
if (inputs[0]) {
auto&& grad_op = WarpPerspectiveBackwardData::make(
param.imode, param.border_mode, param.format, param.border_val);
ValueRefList args_(inputs.size());
args_[0] = mat;
args_[1] = grads[0];
args_[2] = inputs[0];
ret[0] = imperative::apply(*grad_op, args_)[0];
}
if (inputs[1]) {
auto&& grad_op = WarpPerspectiveBackwardMat::make(
param.imode, param.border_mode, param.format, param.border_val);
ValueRefList args_(inputs.size());
args_[0] = inputs[0];
args_[1] = mat;
args_[2] = grads[0];
ret[1] = imperative::apply(*grad_op, args_)[0];
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
items.push_back(std::make_tuple(1, true, true, false, false));
auto&& subtensor = Subtensor::make(items);
CompNodeValue::ref_t device = inputs[0].device();
DTypeValue::ref_t dtype = inputs[0].dtype();
int start_idx = 0;
int stop_idx = 2;
auto get_subtensor_index = [&](int idx) {
HostTensorStorage storage(*device);
storage.ensure_size(dtype::Int32().size());
auto* ptr = reinterpret_cast<dt_int32*>(storage.ptr());
ptr[0] = idx;
return imperative::apply(
CreateTensor(
CreateTensor::Unique, *device, dtype::Int32(),
ValueShape({1})),
HostStorage::make(storage))[0];
};
auto start = get_subtensor_index(start_idx);
auto stop = get_subtensor_index(stop_idx);
auto data = ret[1];
mgb_assert(data);
ret[1] = imperative::apply(*subtensor, data, start, stop)[0];
}
return ret;
});
maker.finalize();
return imperative::apply(ApplyOp(op), inputs);
}
struct Init {
Init() {
CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule);
......@@ -610,6 +701,8 @@ struct Init {
CustomBackward::register_grad_rule(MatrixMul::typeinfo(), matrix_mul_grad_rule);
CustomBackward::register_grad_rule(
BatchedMatrixMul::typeinfo(), batched_matrix_mul_grad_rule);
CustomBackward::register_grad_rule(
WarpAffine::typeinfo(), warp_affine_grad_rule);
}
} _;
......
......@@ -14,6 +14,7 @@ import megengine.core.tensor.dtype as dtype
import megengine.functional as F
import megengine.jit as jit
from megengine import Parameter, Tensor, is_cuda_available, tensor
from megengine.autodiff import GradManager
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.utils import make_shape_tuple
......@@ -571,6 +572,48 @@ def test_warp_perspective(dt):
np.testing.assert_equal(outp.numpy(), np.array([[[[5, 6], [9, 10]]]], dtype=dt))
def test_warp_affine_grad():
dy_np = np.arange(1, 10, dtype=np.float32).reshape(1, 1, 3, 3)
x_np = np.arange(1, 10, dtype=np.float32).reshape(1, 1, 3, 3)
mat_np_affine = np.array([[[0.5, 0, 0], [0, 0.5, 0],]]).astype("float32")
mat_np_perspective = np.array([[[0.5, 0, 0], [0, 0.5, 0], [0, 0, 1]]]).astype(
"float32"
)
dmat_affine = Tensor(np.ones((1, 2, 3), dtype=np.float32))
dy_affine = Tensor(dy_np)
x_affine = Tensor(x_np)
mat_affine = Tensor(mat_np_affine)
target_shape_affine = x_affine.shape[2:]
dmat_perspective = Tensor(np.ones((1, 3, 3), dtype=np.float32))
dy_perspective = Tensor(dy_np)
x_perspective = Tensor(x_np)
mat_perspective = Tensor(mat_np_perspective)
target_shape_perspective = x_perspective.shape[2:]
gm = GradManager().attach([x_affine, mat_affine, x_perspective, mat_perspective])
with gm:
y_affine = F.warp_affine(
x_affine, mat_affine, target_shape_affine, format="NCHW"
)
y_perspective = F.warp_perspective(
x_perspective, mat_perspective, target_shape_perspective
)
gm.backward([y_affine, y_perspective], [dy_affine, dy_perspective])
np.testing.assert_allclose(
x_affine.grad.numpy(), x_perspective.grad.numpy(), rtol=1e-5, atol=1e-5
)
np.testing.assert_allclose(
mat_affine.grad.numpy(),
mat_perspective.grad.numpy()[0:1, 0:2, 0:3],
rtol=1e-5,
atol=1e-5,
)
@pytest.mark.parametrize("dt", [np.float32, np.int8, np.uint8, np.float16])
def test_warp_perspective_mat_idx(dt):
inp_shape = (2, 1, 4, 4)
......
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/imgproc.h"
namespace mgb::imperative {
namespace {
namespace warp_perspective_backward_data {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
mgb_assert(inputs.size() == 3);
auto&& op = static_cast<const WarpPerspectiveBackwardData&>(def);
OperatorNodeConfig config{op.make_name()};
return opr::WarpPerspectiveBackwardData::make(
inputs[0], inputs[1], inputs[2], op.param(), config);
}
OP_TRAIT_REG(WarpPerspectiveBackwardData, WarpPerspectiveBackwardData)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // namespace warp_perspective_backward_data
namespace warp_perspective_backward_mat {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
mgb_assert(inputs.size() == 3);
auto&& op = static_cast<const WarpPerspectiveBackwardMat&>(def);
OperatorNodeConfig config{op.make_name()};
return opr::WarpPerspectiveBackwardMat::make(
inputs[0], inputs[1], inputs[2], op.param(), config);
}
OP_TRAIT_REG(WarpPerspectiveBackwardMat, WarpPerspectiveBackwardMat)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // namespace warp_perspective_backward_mat
} // namespace
} // namespace mgb::imperative
905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py
759bfbf27fd3f0dd6b6edf06377e1d6b ../../src/core/include/megbrain/ir/ops.td
2a5851d0e2470d4d045811e7a20b1a3f generated/opdef.h.inl
55b862badeed19aed8e84c5d6f468ff2 generated/opdef.cpp.inl
f3f4c7f0ee1b39392df8a679f6d22596 generated/opdef.py.inl
6b11ca844a7855fdc5eebffaf563a89c generated/opdef.cpy.inl
e35e13523f43b7bea4034a0bf75937b7 ../../src/core/include/megbrain/ir/ops.td
240dccd6f8d42cadfd08c6ca90fe61b1 generated/opdef.h.inl
a79a4058ff18ffd9593ee5db3deef6c4 generated/opdef.cpp.inl
83c179ee7416824fbfab978a097cd4d3 generated/opdef.py.inl
86f70b1052331130f5e4c0ca53e68423 generated/opdef.cpy.inl
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
......@@ -6941,4 +6941,300 @@ OP_TRAIT_REG(WarpPerspective, WarpPerspective)
.props(WarpPerspective_props_impl)
.make_name(WarpPerspective_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(WarpPerspectiveBackwardData);
namespace {
size_t WarpPerspectiveBackwardData_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<WarpPerspectiveBackwardData>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.imode));
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.bmode));
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format));
val = mgb::hash_pair_combine(val, mgb::hash(op_.border_val));
return val;
}
bool WarpPerspectiveBackwardData_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<WarpPerspectiveBackwardData>(),
&&b_ = rhs_.cast_final_safe<WarpPerspectiveBackwardData>();
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.imode != b_.imode) return false;
if (a_.bmode != b_.bmode) return false;
if (a_.format != b_.format) return false;
if (a_.border_val != b_.border_val) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> WarpPerspectiveBackwardData_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<WarpPerspectiveBackwardData>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
switch (op_.imode){
case WarpPerspectiveBackwardData::InterpolationMode::NEAREST:
props_.emplace_back("imode", "NEAREST");
break;
case WarpPerspectiveBackwardData::InterpolationMode::LINEAR:
props_.emplace_back("imode", "LINEAR");
break;
case WarpPerspectiveBackwardData::InterpolationMode::AREA:
props_.emplace_back("imode", "AREA");
break;
case WarpPerspectiveBackwardData::InterpolationMode::CUBIC:
props_.emplace_back("imode", "CUBIC");
break;
case WarpPerspectiveBackwardData::InterpolationMode::LANCZOS4:
props_.emplace_back("imode", "LANCZOS4");
break;
default:
props_.emplace_back("imode", "INVALID");
break;
}
switch (op_.bmode){
case WarpPerspectiveBackwardData::BorderMode::REPLICATE:
props_.emplace_back("bmode", "REPLICATE");
break;
case WarpPerspectiveBackwardData::BorderMode::REFLECT:
props_.emplace_back("bmode", "REFLECT");
break;
case WarpPerspectiveBackwardData::BorderMode::REFLECT_101:
props_.emplace_back("bmode", "REFLECT_101");
break;
case WarpPerspectiveBackwardData::BorderMode::WRAP:
props_.emplace_back("bmode", "WRAP");
break;
case WarpPerspectiveBackwardData::BorderMode::CONSTANT:
props_.emplace_back("bmode", "CONSTANT");
break;
case WarpPerspectiveBackwardData::BorderMode::TRANSPARENT:
props_.emplace_back("bmode", "TRANSPARENT");
break;
case WarpPerspectiveBackwardData::BorderMode::ISOLATED:
props_.emplace_back("bmode", "ISOLATED");
break;
default:
props_.emplace_back("bmode", "INVALID");
break;
}
switch (op_.format){
case WarpPerspectiveBackwardData::Format::NCHW:
props_.emplace_back("format", "NCHW");
break;
case WarpPerspectiveBackwardData::Format::NHWC:
props_.emplace_back("format", "NHWC");
break;
case WarpPerspectiveBackwardData::Format::NHWCD4:
props_.emplace_back("format", "NHWCD4");
break;
case WarpPerspectiveBackwardData::Format::NCHW4:
props_.emplace_back("format", "NCHW4");
break;
case WarpPerspectiveBackwardData::Format::NCHW8:
props_.emplace_back("format", "NCHW8");
break;
case WarpPerspectiveBackwardData::Format::NCHW32:
props_.emplace_back("format", "NCHW32");
break;
case WarpPerspectiveBackwardData::Format::NCHW88:
props_.emplace_back("format", "NCHW88");
break;
case WarpPerspectiveBackwardData::Format::NCHW44:
props_.emplace_back("format", "NCHW44");
break;
case WarpPerspectiveBackwardData::Format::NCHW44_DOT:
props_.emplace_back("format", "NCHW44_DOT");
break;
case WarpPerspectiveBackwardData::Format::NCHW4_NCHW32:
props_.emplace_back("format", "NCHW4_NCHW32");
break;
case WarpPerspectiveBackwardData::Format::NCHW32_NCHW4:
props_.emplace_back("format", "NCHW32_NCHW4");
break;
case WarpPerspectiveBackwardData::Format::NCHW4_NCHW:
props_.emplace_back("format", "NCHW4_NCHW");
break;
case WarpPerspectiveBackwardData::Format::NHWC_NCHW:
props_.emplace_back("format", "NHWC_NCHW");
break;
case WarpPerspectiveBackwardData::Format::NHWC_NCHW4_IC_SMALL:
props_.emplace_back("format", "NHWC_NCHW4_IC_SMALL");
break;
case WarpPerspectiveBackwardData::Format::NCHW_NCHW4_IC_SMALL:
props_.emplace_back("format", "NCHW_NCHW4_IC_SMALL");
break;
case WarpPerspectiveBackwardData::Format::CHWN4:
props_.emplace_back("format", "CHWN4");
break;
case WarpPerspectiveBackwardData::Format::NCHW64:
props_.emplace_back("format", "NCHW64");
break;
case WarpPerspectiveBackwardData::Format::NCHW4_NHWC:
props_.emplace_back("format", "NCHW4_NHWC");
break;
default:
props_.emplace_back("format", "INVALID");
break;
}
props_.emplace_back("border_val", std::to_string(op_.border_val));
return props_;
}
std::string WarpPerspectiveBackwardData_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<WarpPerspectiveBackwardData>();
static_cast<void>(op_);
return "WarpPerspectiveBackwardData";
}
} // anonymous namespace
OP_TRAIT_REG(WarpPerspectiveBackwardData, WarpPerspectiveBackwardData)
.hash(WarpPerspectiveBackwardData_hash_impl)
.is_same_st(WarpPerspectiveBackwardData_is_same_st_impl)
.props(WarpPerspectiveBackwardData_props_impl)
.make_name(WarpPerspectiveBackwardData_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(WarpPerspectiveBackwardMat);
namespace {
size_t WarpPerspectiveBackwardMat_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<WarpPerspectiveBackwardMat>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.imode));
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.bmode));
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format));
val = mgb::hash_pair_combine(val, mgb::hash(op_.border_val));
return val;
}
bool WarpPerspectiveBackwardMat_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<WarpPerspectiveBackwardMat>(),
&&b_ = rhs_.cast_final_safe<WarpPerspectiveBackwardMat>();
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.imode != b_.imode) return false;
if (a_.bmode != b_.bmode) return false;
if (a_.format != b_.format) return false;
if (a_.border_val != b_.border_val) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> WarpPerspectiveBackwardMat_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<WarpPerspectiveBackwardMat>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
switch (op_.imode){
case WarpPerspectiveBackwardMat::InterpolationMode::NEAREST:
props_.emplace_back("imode", "NEAREST");
break;
case WarpPerspectiveBackwardMat::InterpolationMode::LINEAR:
props_.emplace_back("imode", "LINEAR");
break;
case WarpPerspectiveBackwardMat::InterpolationMode::AREA:
props_.emplace_back("imode", "AREA");
break;
case WarpPerspectiveBackwardMat::InterpolationMode::CUBIC:
props_.emplace_back("imode", "CUBIC");
break;
case WarpPerspectiveBackwardMat::InterpolationMode::LANCZOS4:
props_.emplace_back("imode", "LANCZOS4");
break;
default:
props_.emplace_back("imode", "INVALID");
break;
}
switch (op_.bmode){
case WarpPerspectiveBackwardMat::BorderMode::REPLICATE:
props_.emplace_back("bmode", "REPLICATE");
break;
case WarpPerspectiveBackwardMat::BorderMode::REFLECT:
props_.emplace_back("bmode", "REFLECT");
break;
case WarpPerspectiveBackwardMat::BorderMode::REFLECT_101:
props_.emplace_back("bmode", "REFLECT_101");
break;
case WarpPerspectiveBackwardMat::BorderMode::WRAP:
props_.emplace_back("bmode", "WRAP");
break;
case WarpPerspectiveBackwardMat::BorderMode::CONSTANT:
props_.emplace_back("bmode", "CONSTANT");
break;
case WarpPerspectiveBackwardMat::BorderMode::TRANSPARENT:
props_.emplace_back("bmode", "TRANSPARENT");
break;
case WarpPerspectiveBackwardMat::BorderMode::ISOLATED:
props_.emplace_back("bmode", "ISOLATED");
break;
default:
props_.emplace_back("bmode", "INVALID");
break;
}
switch (op_.format){
case WarpPerspectiveBackwardMat::Format::NCHW:
props_.emplace_back("format", "NCHW");
break;
case WarpPerspectiveBackwardMat::Format::NHWC:
props_.emplace_back("format", "NHWC");
break;
case WarpPerspectiveBackwardMat::Format::NHWCD4:
props_.emplace_back("format", "NHWCD4");
break;
case WarpPerspectiveBackwardMat::Format::NCHW4:
props_.emplace_back("format", "NCHW4");
break;
case WarpPerspectiveBackwardMat::Format::NCHW8:
props_.emplace_back("format", "NCHW8");
break;
case WarpPerspectiveBackwardMat::Format::NCHW32:
props_.emplace_back("format", "NCHW32");
break;
case WarpPerspectiveBackwardMat::Format::NCHW88:
props_.emplace_back("format", "NCHW88");
break;
case WarpPerspectiveBackwardMat::Format::NCHW44:
props_.emplace_back("format", "NCHW44");
break;
case WarpPerspectiveBackwardMat::Format::NCHW44_DOT:
props_.emplace_back("format", "NCHW44_DOT");
break;
case WarpPerspectiveBackwardMat::Format::NCHW4_NCHW32:
props_.emplace_back("format", "NCHW4_NCHW32");
break;
case WarpPerspectiveBackwardMat::Format::NCHW32_NCHW4:
props_.emplace_back("format", "NCHW32_NCHW4");
break;
case WarpPerspectiveBackwardMat::Format::NCHW4_NCHW:
props_.emplace_back("format", "NCHW4_NCHW");
break;
case WarpPerspectiveBackwardMat::Format::NHWC_NCHW:
props_.emplace_back("format", "NHWC_NCHW");
break;
case WarpPerspectiveBackwardMat::Format::NHWC_NCHW4_IC_SMALL:
props_.emplace_back("format", "NHWC_NCHW4_IC_SMALL");
break;
case WarpPerspectiveBackwardMat::Format::NCHW_NCHW4_IC_SMALL:
props_.emplace_back("format", "NCHW_NCHW4_IC_SMALL");
break;
case WarpPerspectiveBackwardMat::Format::CHWN4:
props_.emplace_back("format", "CHWN4");
break;
case WarpPerspectiveBackwardMat::Format::NCHW64:
props_.emplace_back("format", "NCHW64");
break;
case WarpPerspectiveBackwardMat::Format::NCHW4_NHWC:
props_.emplace_back("format", "NCHW4_NHWC");
break;
default:
props_.emplace_back("format", "INVALID");
break;
}
props_.emplace_back("border_val", std::to_string(op_.border_val));
return props_;
}
std::string WarpPerspectiveBackwardMat_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<WarpPerspectiveBackwardMat>();
static_cast<void>(op_);
return "WarpPerspectiveBackwardMat";
}
} // anonymous namespace
OP_TRAIT_REG(WarpPerspectiveBackwardMat, WarpPerspectiveBackwardMat)
.hash(WarpPerspectiveBackwardMat_hash_impl)
.is_same_st(WarpPerspectiveBackwardMat_is_same_st_impl)
.props(WarpPerspectiveBackwardMat_props_impl)
.make_name(WarpPerspectiveBackwardMat_make_name_impl);
// clang-format on
......@@ -18185,6 +18185,346 @@ void _init_py_WarpPerspective(py::module m) {
m.add_object("WarpPerspective", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(WarpPerspective::typeinfo(), &py_type).second);
}
void _init_py_WarpPerspectiveBackwardData_InterpolationMode(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<WarpPerspectiveBackwardData::InterpolationMode>::type;
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "InterpolationMode", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
void _init_py_WarpPerspectiveBackwardData_BorderMode(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<WarpPerspectiveBackwardData::BorderMode>::type;
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "BorderMode", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
void _init_py_WarpPerspectiveBackwardData_Format(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<WarpPerspectiveBackwardData::Format>::type;
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "Format", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
PyOpDefBegin(WarpPerspectiveBackwardData) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(WarpPerspectiveBackwardData)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"imode", serialization<decltype(opdef.imode)>::dump(opdef.imode)},
{"bmode", serialization<decltype(opdef.bmode)>::dump(opdef.bmode)},
{"format", serialization<decltype(opdef.format)>::dump(opdef.format)},
{"border_val", serialization<decltype(opdef.border_val)>::dump(opdef.border_val)}
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(WarpPerspectiveBackwardData)*>(self)->inst();
static_cast<void>(opdef);
{
auto&& iter = state.find("imode");
if (iter != state.end()) {
opdef.imode = serialization<decltype(opdef.imode)>::load(iter->second);
}
}
{
auto&& iter = state.find("bmode");
if (iter != state.end()) {
opdef.bmode = serialization<decltype(opdef.bmode)>::load(iter->second);
}
}
{
auto&& iter = state.find("format");
if (iter != state.end()) {
opdef.format = serialization<decltype(opdef.format)>::load(iter->second);
}
}
{
auto&& iter = state.find("border_val");
if (iter != state.end()) {
opdef.border_val = serialization<decltype(opdef.border_val)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
// };
PyOpDefEnd(WarpPerspectiveBackwardData)
int PyOp(WarpPerspectiveBackwardData)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"imode", "bmode", "format", "border_val", "scope", NULL};
PyObject *imode = NULL, *bmode = NULL, *format = NULL, *border_val = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOO", const_cast<char**>(kwlist), &imode, &bmode, &format, &border_val, &scope))
return -1;
if (imode) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(WarpPerspectiveBackwardData)*>(self)->inst().imode =
py::cast<decltype(WarpPerspectiveBackwardData::imode)>(py::handle(imode));
} CATCH_ALL(-1)
}
if (bmode) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(WarpPerspectiveBackwardData)*>(self)->inst().bmode =
py::cast<decltype(WarpPerspectiveBackwardData::bmode)>(py::handle(bmode));
} CATCH_ALL(-1)
}
if (format) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(WarpPerspectiveBackwardData)*>(self)->inst().format =
py::cast<decltype(WarpPerspectiveBackwardData::format)>(py::handle(format));
} CATCH_ALL(-1)
}
if (border_val) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(WarpPerspectiveBackwardData)*>(self)->inst().border_val =
py::cast<decltype(WarpPerspectiveBackwardData::border_val)>(py::handle(border_val));
} CATCH_ALL(-1)
}
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}
return 0;
}
PyGetSetDef PyOp(WarpPerspectiveBackwardData)::py_getsetters[] = {
{const_cast<char*>("imode"), py_get_generic(WarpPerspectiveBackwardData, imode), py_set_generic(WarpPerspectiveBackwardData, imode), const_cast<char*>("imode"), NULL},
{const_cast<char*>("bmode"), py_get_generic(WarpPerspectiveBackwardData, bmode), py_set_generic(WarpPerspectiveBackwardData, bmode), const_cast<char*>("bmode"), NULL},
{const_cast<char*>("format"), py_get_generic(WarpPerspectiveBackwardData, format), py_set_generic(WarpPerspectiveBackwardData, format), const_cast<char*>("format"), NULL},
{const_cast<char*>("border_val"), py_get_generic(WarpPerspectiveBackwardData, border_val), py_set_generic(WarpPerspectiveBackwardData, border_val), const_cast<char*>("border_val"), NULL},
{NULL} /* Sentinel */
};
PyMethodDef PyOp(WarpPerspectiveBackwardData)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(WarpPerspectiveBackwardData)::getstate, METH_NOARGS, "WarpPerspectiveBackwardData getstate"},
{const_cast<char*>("__setstate__"), PyOp(WarpPerspectiveBackwardData)::setstate, METH_VARARGS, "WarpPerspectiveBackwardData setstate"},
{NULL} /* Sentinel */
};
void _init_py_WarpPerspectiveBackwardData(py::module m) {
using py_op = PyOp(WarpPerspectiveBackwardData);
auto& py_type = PyOpType(WarpPerspectiveBackwardData);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.WarpPerspectiveBackwardData";
py_type.tp_basicsize = sizeof(PyOp(WarpPerspectiveBackwardData));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "WarpPerspectiveBackwardData";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
_init_py_WarpPerspectiveBackwardData_InterpolationMode(py_type);
_init_py_WarpPerspectiveBackwardData_BorderMode(py_type);
_init_py_WarpPerspectiveBackwardData_Format(py_type);
PyType_Modified(&py_type);
m.add_object("WarpPerspectiveBackwardData", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(WarpPerspectiveBackwardData::typeinfo(), &py_type).second);
}
void _init_py_WarpPerspectiveBackwardMat_InterpolationMode(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<WarpPerspectiveBackwardMat::InterpolationMode>::type;
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "InterpolationMode", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
void _init_py_WarpPerspectiveBackwardMat_BorderMode(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<WarpPerspectiveBackwardMat::BorderMode>::type;
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "BorderMode", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
void _init_py_WarpPerspectiveBackwardMat_Format(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<WarpPerspectiveBackwardMat::Format>::type;
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "Format", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
PyOpDefBegin(WarpPerspectiveBackwardMat) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(WarpPerspectiveBackwardMat)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"imode", serialization<decltype(opdef.imode)>::dump(opdef.imode)},
{"bmode", serialization<decltype(opdef.bmode)>::dump(opdef.bmode)},
{"format", serialization<decltype(opdef.format)>::dump(opdef.format)},
{"border_val", serialization<decltype(opdef.border_val)>::dump(opdef.border_val)}
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(WarpPerspectiveBackwardMat)*>(self)->inst();
static_cast<void>(opdef);
{
auto&& iter = state.find("imode");
if (iter != state.end()) {
opdef.imode = serialization<decltype(opdef.imode)>::load(iter->second);
}
}
{
auto&& iter = state.find("bmode");
if (iter != state.end()) {
opdef.bmode = serialization<decltype(opdef.bmode)>::load(iter->second);
}
}
{
auto&& iter = state.find("format");
if (iter != state.end()) {
opdef.format = serialization<decltype(opdef.format)>::load(iter->second);
}
}
{
auto&& iter = state.find("border_val");
if (iter != state.end()) {
opdef.border_val = serialization<decltype(opdef.border_val)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
// };
PyOpDefEnd(WarpPerspectiveBackwardMat)
int PyOp(WarpPerspectiveBackwardMat)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"imode", "bmode", "format", "border_val", "scope", NULL};
PyObject *imode = NULL, *bmode = NULL, *format = NULL, *border_val = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOO", const_cast<char**>(kwlist), &imode, &bmode, &format, &border_val, &scope))
return -1;
if (imode) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(WarpPerspectiveBackwardMat)*>(self)->inst().imode =
py::cast<decltype(WarpPerspectiveBackwardMat::imode)>(py::handle(imode));
} CATCH_ALL(-1)
}
if (bmode) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(WarpPerspectiveBackwardMat)*>(self)->inst().bmode =
py::cast<decltype(WarpPerspectiveBackwardMat::bmode)>(py::handle(bmode));
} CATCH_ALL(-1)
}
if (format) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(WarpPerspectiveBackwardMat)*>(self)->inst().format =
py::cast<decltype(WarpPerspectiveBackwardMat::format)>(py::handle(format));
} CATCH_ALL(-1)
}
if (border_val) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(WarpPerspectiveBackwardMat)*>(self)->inst().border_val =
py::cast<decltype(WarpPerspectiveBackwardMat::border_val)>(py::handle(border_val));
} CATCH_ALL(-1)
}
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}
return 0;
}
PyGetSetDef PyOp(WarpPerspectiveBackwardMat)::py_getsetters[] = {
{const_cast<char*>("imode"), py_get_generic(WarpPerspectiveBackwardMat, imode), py_set_generic(WarpPerspectiveBackwardMat, imode), const_cast<char*>("imode"), NULL},
{const_cast<char*>("bmode"), py_get_generic(WarpPerspectiveBackwardMat, bmode), py_set_generic(WarpPerspectiveBackwardMat, bmode), const_cast<char*>("bmode"), NULL},
{const_cast<char*>("format"), py_get_generic(WarpPerspectiveBackwardMat, format), py_set_generic(WarpPerspectiveBackwardMat, format), const_cast<char*>("format"), NULL},
{const_cast<char*>("border_val"), py_get_generic(WarpPerspectiveBackwardMat, border_val), py_set_generic(WarpPerspectiveBackwardMat, border_val), const_cast<char*>("border_val"), NULL},
{NULL} /* Sentinel */
};
PyMethodDef PyOp(WarpPerspectiveBackwardMat)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(WarpPerspectiveBackwardMat)::getstate, METH_NOARGS, "WarpPerspectiveBackwardMat getstate"},
{const_cast<char*>("__setstate__"), PyOp(WarpPerspectiveBackwardMat)::setstate, METH_VARARGS, "WarpPerspectiveBackwardMat setstate"},
{NULL} /* Sentinel */
};
void _init_py_WarpPerspectiveBackwardMat(py::module m) {
using py_op = PyOp(WarpPerspectiveBackwardMat);
auto& py_type = PyOpType(WarpPerspectiveBackwardMat);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.WarpPerspectiveBackwardMat";
py_type.tp_basicsize = sizeof(PyOp(WarpPerspectiveBackwardMat));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "WarpPerspectiveBackwardMat";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
_init_py_WarpPerspectiveBackwardMat_InterpolationMode(py_type);
_init_py_WarpPerspectiveBackwardMat_BorderMode(py_type);
_init_py_WarpPerspectiveBackwardMat_Format(py_type);
PyType_Modified(&py_type);
m.add_object("WarpPerspectiveBackwardMat", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(WarpPerspectiveBackwardMat::typeinfo(), &py_type).second);
}
#define INIT_ALL_OP(m) \
_init_py_AdaptivePooling(m); \
_init_py_AddAxis(m); \
......@@ -18290,5 +18630,7 @@ void _init_py_WarpPerspective(py::module m) {
_init_py_TypeCvt(m); \
_init_py_UniformRNG(m); \
_init_py_WarpAffine(m); \
_init_py_WarpPerspective(m);
_init_py_WarpPerspective(m); \
_init_py_WarpPerspectiveBackwardData(m); \
_init_py_WarpPerspectiveBackwardMat(m);
// clang-format on
......@@ -1795,4 +1795,42 @@ public:
}
};
class WarpPerspectiveBackwardData : public OpDefImplBase<WarpPerspectiveBackwardData> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using InterpolationMode = ::megdnn::param::WarpPerspective::InterpolationMode;
using BorderMode = ::megdnn::param::WarpPerspective::BorderMode;
using Format = ::megdnn::param::WarpPerspective::Format;
InterpolationMode imode = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR;
BorderMode bmode = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE;
Format format = ::megdnn::param::WarpPerspective::Format::NCHW;
float border_val = .0f;
WarpPerspectiveBackwardData() = default;
WarpPerspectiveBackwardData(InterpolationMode imode_, BorderMode bmode_, Format format_, float border_val_, std::string scope_ = {}): imode(imode_), bmode(bmode_), format(format_), border_val(border_val_) { set_scope(scope_); }
WarpPerspectiveBackwardData(::megdnn::param::WarpPerspective packed_param_0): imode(packed_param_0.imode), bmode(packed_param_0.bmode), format(packed_param_0.format), border_val(packed_param_0.border_val) {}
::megdnn::param::WarpPerspective param() const {
return {imode, bmode, format, border_val};
}
};
class WarpPerspectiveBackwardMat : public OpDefImplBase<WarpPerspectiveBackwardMat> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using InterpolationMode = ::megdnn::param::WarpPerspective::InterpolationMode;
using BorderMode = ::megdnn::param::WarpPerspective::BorderMode;
using Format = ::megdnn::param::WarpPerspective::Format;
InterpolationMode imode = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR;
BorderMode bmode = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE;
Format format = ::megdnn::param::WarpPerspective::Format::NCHW;
float border_val = .0f;
WarpPerspectiveBackwardMat() = default;
WarpPerspectiveBackwardMat(InterpolationMode imode_, BorderMode bmode_, Format format_, float border_val_, std::string scope_ = {}): imode(imode_), bmode(bmode_), format(format_), border_val(border_val_) { set_scope(scope_); }
WarpPerspectiveBackwardMat(::megdnn::param::WarpPerspective packed_param_0): imode(packed_param_0.imode), bmode(packed_param_0.bmode), format(packed_param_0.format), border_val(packed_param_0.border_val) {}
::megdnn::param::WarpPerspective param() const {
return {imode, bmode, format, border_val};
}
};
// clang-format on
......@@ -1858,4 +1858,34 @@ WarpPerspectiveInst
.def_readwrite("format", &WarpPerspective::format)
.def_readwrite("border_val", &WarpPerspective::border_val);
py::class_<WarpPerspectiveBackwardData, std::shared_ptr<WarpPerspectiveBackwardData>, OpDef> WarpPerspectiveBackwardDataInst(m, "WarpPerspectiveBackwardData");
WarpPerspectiveBackwardDataInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode");
WarpPerspectiveBackwardDataInst.attr("BorderMode") = RemapInst.attr("BorderMode");
WarpPerspectiveBackwardDataInst.attr("Format") = AdaptivePoolingInst.attr("Format");
WarpPerspectiveBackwardDataInst
.def(py::init<::megdnn::param::WarpPerspective::InterpolationMode, ::megdnn::param::WarpPerspective::BorderMode, ::megdnn::param::WarpPerspective::Format, float, std::string>(), py::arg("imode") = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR, py::arg("bmode") = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE, py::arg("format") = ::megdnn::param::WarpPerspective::Format::NCHW, py::arg("border_val") = .0f, py::arg("scope") = {})
.def_readwrite("imode", &WarpPerspectiveBackwardData::imode)
.def_readwrite("bmode", &WarpPerspectiveBackwardData::bmode)
.def_readwrite("format", &WarpPerspectiveBackwardData::format)
.def_readwrite("border_val", &WarpPerspectiveBackwardData::border_val);
py::class_<WarpPerspectiveBackwardMat, std::shared_ptr<WarpPerspectiveBackwardMat>, OpDef> WarpPerspectiveBackwardMatInst(m, "WarpPerspectiveBackwardMat");
WarpPerspectiveBackwardMatInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode");
WarpPerspectiveBackwardMatInst.attr("BorderMode") = RemapInst.attr("BorderMode");
WarpPerspectiveBackwardMatInst.attr("Format") = AdaptivePoolingInst.attr("Format");
WarpPerspectiveBackwardMatInst
.def(py::init<::megdnn::param::WarpPerspective::InterpolationMode, ::megdnn::param::WarpPerspective::BorderMode, ::megdnn::param::WarpPerspective::Format, float, std::string>(), py::arg("imode") = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR, py::arg("bmode") = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE, py::arg("format") = ::megdnn::param::WarpPerspective::Format::NCHW, py::arg("border_val") = .0f, py::arg("scope") = {})
.def_readwrite("imode", &WarpPerspectiveBackwardMat::imode)
.def_readwrite("bmode", &WarpPerspectiveBackwardMat::bmode)
.def_readwrite("format", &WarpPerspectiveBackwardMat::format)
.def_readwrite("border_val", &WarpPerspectiveBackwardMat::border_val);
// clang-format on
......@@ -104,6 +104,10 @@ def WarpPerspective: MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>;
def WarpAffine: MgbHashableOp<"WarpAffine", [WarpAffineParam]>;
def WarpPerspectiveBackwardData: MgbHashableOp<"WarpPerspectiveBackwardData", [WarpPerspectiveParam]>;
def WarpPerspectiveBackwardMat: MgbHashableOp<"WarpPerspectiveBackwardMat", [WarpPerspectiveParam]>;
def Remap: MgbHashableOp<"Remap", [RemapParam]>;
def Resize: MgbHashableOp<"Resize", [ResizeParam]>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册