提交 2ac3c9dc 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(trace): treat constants in backward graph as ImmutableTensor correctly

GitOrigin-RevId: 5fd6a5e00c5b5ddeb19cedde62f2739e76c7bf41
上级 7eea1fc6
......@@ -13,6 +13,7 @@
#include "megbrain/common.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/opr/io.h"
#include "./tensor.h"
#include "./grad.h"
......@@ -39,6 +40,26 @@ interpreter::Interpreter::Channel* interpreter_for_py;
PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing;
PyObject *cpp_apply_backward_varnode;
/*!
 * \brief Wrap a constant value into a Tensor, routing it through the Python
 *        tracing hook when tracing is enabled.
 *
 * When the TRACE flag is not set in ApplyContext::global_enable, the device
 * tensor of \p value is put directly into the interpreter. Otherwise the
 * constant is handed to the registered `cpp_apply_const_with_tracing` Python
 * callable, and the tensor held by the first element of its result is
 * returned.
 *
 * \param value constant to materialize
 * \return the Tensor backing the constant
 * \throws py::error_already_set if the Python callable raises
 * \throws py::type_error if the callable's first result is not a tensor
 */
std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) {
    if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) {
        return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor()));
    }
    // Positional arguments for the Python-side hook. Slots 3..5 are fixed
    // flags whose meaning is defined by the Python callee.
    // NOTE(review): confirm against the signature of the registered
    // apply_const_with_tracing function.
    py::tuple tup(6);
    auto data = value->get_value();
    tup[0] = py::reinterpret_steal<py::array>(ndarray_from_tensor(data, npy::ShareType::MUST_SHARE));
    tup[1] = value->dtype();
    tup[2] = value->comp_node();
    tup[3] = true;
    tup[4] = false;
    tup[5] = py::none{};
    auto py_ret = PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr);
    if (!py_ret) throw py::error_already_set();
    // Take ownership of the new reference returned by PyObject_Call.
    auto py_list = py::reinterpret_steal<py::list>(py_ret);
    auto* tensor_wrapper = TensorWrapper::try_cast(py_list[0].ptr());
    // try_cast yields a pointer; guard against a failed cast instead of
    // dereferencing a null wrapper.
    if (!tensor_wrapper) {
        throw py::type_error("apply_const_with_tracing did not return a Tensor");
    }
    return tensor_wrapper->m_tensor;
}
#define REGISTE_APPLY_FUNC(mode) \
void set_##mode(py::object pyf) { \
mode = pyf.ptr(); \
......
......@@ -271,18 +271,17 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
return apply(op, args, nargs);
}
std::shared_ptr<Tensor> make_const(imperative::TensorPtr value);
/*!
 * \brief Apply a Subgraph to \p nargs input tensors.
 *
 * Each input is converted to a shared_ptr via shared_from_this(). Operators
 * in the graph are dispatched through the op-level apply() overload, and
 * constants are materialized with make_const(), so that they go through the
 * tracing path when tracing is enabled.
 *
 * \param graph subgraph to evaluate
 * \param args  array of \p nargs non-null tensor pointers
 * \param nargs number of entries in \p args
 * \return whatever Subgraph::apply returns for the given functors
 */
inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) {
    SmallVector<std::shared_ptr<Tensor>> inputs;
    for (size_t i = 0; i < nargs; ++i) {
        inputs.push_back(args[i]->shared_from_this());
    }
    // `inputs` is taken by value here, so forward it on with a move rather
    // than copying the vector of shared_ptrs a second time.
    auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<std::shared_ptr<Tensor>> inputs) {
        return apply(op, std::move(inputs));
    };
    return graph.apply(inputs, apply_functor, &make_const);
}
template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册