From e9e5f442a7acb47acacae45644f8f46eed35cfa8 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 31 Dec 2020 17:45:12 +0800 Subject: [PATCH] fix(mge): expand custom op before trace GitOrigin-RevId: 725a5b87cb7deb62c8ef1f4f919278b57b57ffb7 --- imperative/python/src/tensor.cpp | 41 ++++++++++++++++---------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 396ed644a..1b5bc4157 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -96,29 +96,30 @@ apply_result_t apply(ApplyContext& ctx) { return apply_grad(ctx); } + if (auto* op = ctx.op->try_cast_final()) { + py::tuple pyin(ctx.nargs); + for (size_t i = 0; i < ctx.nargs; ++i) { + pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this()); + } + auto f = py::getattr(op->obj, "_default_rule"); + auto pyout = py::reinterpret_steal(PyObject_Call(f.ptr(), pyin.ptr(), nullptr)); + if (!pyout) throw py::error_already_set(); + if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) { + return {tw->m_tensor}; + } + apply_result_t ret; + ret.reserve(py::len(pyout)); + for (auto&& i : pyout) { + auto* tw = TensorWrapper::try_cast(i.ptr()); + mgb_assert(tw); + ret.push_back(tw->m_tensor); + } + return ret; + } + if (flags & Tensor::Flags::TRACE) { return apply_trace(ctx); } else { - if (auto* op = ctx.op->try_cast_final()) { - py::tuple pyin(ctx.nargs); - for (size_t i = 0; i < ctx.nargs; ++i) { - pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this()); - } - auto f = py::getattr(op->obj, "_default_rule"); - auto pyout = py::reinterpret_steal(PyObject_Call(f.ptr(), pyin.ptr(), nullptr)); - if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) { - return {tw->m_tensor}; - } - apply_result_t ret; - ret.reserve(py::len(pyout)); - for (auto&& i : pyout) { - auto* tw = TensorWrapper::try_cast(i.ptr()); - mgb_assert(tw); - ret.push_back(tw->m_tensor); - } - return ret; - } - SmallVector handles(ctx.nargs); for (size_t i = 0; i < ctx.nargs; ++i) { handles[i] = ctx.args[i]->m_handle.get(); -- GitLab