From 2ac3c9dc8bd3bc0636ebb00fdfd26cd120b1152c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 15 Jun 2021 19:18:00 +0800 Subject: [PATCH] fix(trace): constants in backward graph treat as ImmutableTensor corectlly GitOrigin-RevId: 5fd6a5e00c5b5ddeb19cedde62f2739e76c7bf41 --- imperative/python/src/tensor.cpp | 21 +++++++++++++++++++++ imperative/python/src/tensor.h | 9 ++++----- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 7c6c9ddd4..28357b226 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -13,6 +13,7 @@ #include "megbrain/common.h" #include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/ops/backward_graph.h" +#include "megbrain/opr/io.h" #include "./tensor.h" #include "./grad.h" @@ -39,6 +40,26 @@ interpreter::Interpreter::Channel* interpreter_for_py; PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing; PyObject *cpp_apply_backward_varnode; +std::shared_ptr make_const(imperative::TensorPtr value) { + if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) { + return std::make_shared(interpreter_for_py->put(value->dev_tensor())); + } + py::tuple tup(6); + auto data = value->get_value(); + tup[0] = py::reinterpret_steal(ndarray_from_tensor(data, npy::ShareType::MUST_SHARE)); + tup[1] = value->dtype(); + tup[2] = value->comp_node(); + tup[3] = true; + tup[4] = false; + tup[5] = py::none{}; + auto py_ret = PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr); + if (!py_ret) throw py::error_already_set(); + auto py_list = py::reinterpret_steal(py_ret); + auto* tensor_wrapper = TensorWrapper::try_cast(py_list[0].ptr()); + auto tensor = tensor_wrapper->m_tensor; + return tensor_wrapper->m_tensor; +} + #define REGISTE_APPLY_FUNC(mode) \ void set_##mode(py::object pyf) { \ mode = pyf.ptr(); \ diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index be6fd1839..f9409a3f2 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -271,18 +271,17 @@ auto apply(std::shared_ptr op, T&& tensors) return apply(op, args, nargs); } +std::shared_ptr make_const(imperative::TensorPtr value); + inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) { SmallVector> inputs; for (size_t i = 0; i < nargs; ++i) { inputs.push_back(args[i]->shared_from_this()); } auto apply_functor = [](std::shared_ptr op, SmallVector> inputs) { - return apply(op, inputs); - }; - auto const_functor = [](imperative::TensorPtr value) { - return std::make_shared(interpreter_for_py->put(value->dev_tensor())); + return apply(op, std::move(inputs)); }; - return graph.apply(inputs, apply_functor, const_functor); + return graph.apply(inputs, apply_functor, &make_const); } template -- GitLab