提交 e9e5f442 编写于 作者: M Megvii Engine Team

fix(mge): expand custom op before trace

GitOrigin-RevId: 725a5b87cb7deb62c8ef1f4f919278b57b57ffb7
上级 3faba54f
...@@ -96,9 +96,6 @@ apply_result_t apply(ApplyContext& ctx) { ...@@ -96,9 +96,6 @@ apply_result_t apply(ApplyContext& ctx) {
return apply_grad(ctx); return apply_grad(ctx);
} }
if (flags & Tensor::Flags::TRACE) {
return apply_trace(ctx);
} else {
if (auto* op = ctx.op->try_cast_final<GenericPyOp>()) { if (auto* op = ctx.op->try_cast_final<GenericPyOp>()) {
py::tuple pyin(ctx.nargs); py::tuple pyin(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) { for (size_t i = 0; i < ctx.nargs; ++i) {
...@@ -106,6 +103,7 @@ apply_result_t apply(ApplyContext& ctx) { ...@@ -106,6 +103,7 @@ apply_result_t apply(ApplyContext& ctx) {
} }
auto f = py::getattr(op->obj, "_default_rule"); auto f = py::getattr(op->obj, "_default_rule");
auto pyout = py::reinterpret_steal<py::object>(PyObject_Call(f.ptr(), pyin.ptr(), nullptr)); auto pyout = py::reinterpret_steal<py::object>(PyObject_Call(f.ptr(), pyin.ptr(), nullptr));
if (!pyout) throw py::error_already_set();
if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) { if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) {
return {tw->m_tensor}; return {tw->m_tensor};
} }
...@@ -119,6 +117,9 @@ apply_result_t apply(ApplyContext& ctx) { ...@@ -119,6 +117,9 @@ apply_result_t apply(ApplyContext& ctx) {
return ret; return ret;
} }
if (flags & Tensor::Flags::TRACE) {
return apply_trace(ctx);
} else {
SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs); SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) { for (size_t i = 0; i < ctx.nargs; ++i) {
handles[i] = ctx.args[i]->m_handle.get(); handles[i] = ctx.args[i]->m_handle.get();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册