提交 3246ee5e 编写于 作者: M Megvii Engine Team

perf(mge): use DepType::HOST_VALUE in trace when possible

GitOrigin-RevId: 5d47ed263fe0c5d65f86d53d61bd1c427139c06d
上级 0e303710
......@@ -483,13 +483,13 @@ void init_graph_rt(py::module m) {
},
py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none());
auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs, bool borrow = false) {
auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs, bool borrow = false, bool prefer_host_value = false) {
SymbolVarArray sinputs;
for (auto i : inputs) {
sinputs.emplace_back(i);
}
static_assert(!std::is_reference<decltype(callback)>::value);
opr::OutputCallback::Param param{std::move(callback), borrow};
opr::OutputCallback::Param param{std::move(callback), borrow, prefer_host_value};
auto output = opr::OutputCallback::make(std::move(param), sinputs);
return output.node();
};
......@@ -519,7 +519,7 @@ void init_graph_rt(py::module m) {
hv_with_event.second->record();
p->set(std::move(hv_with_event));
};
return output_callback(std::move(f), std::move(inputs), true);
return output_callback(std::move(f), std::move(inputs), true, true);
});
m.def("attr_output_callback", [output_callback](std::shared_ptr<Rendezvous<TensorAttr>> p, std::vector<cg::VarNode*> inputs) {
......
......@@ -144,13 +144,24 @@ cg::OperatorNodeBase::NodeProp* OutputCallback::do_make_node_prop() const {
prop->add_flag(NodeProp::Flag::NO_AUTOMATIC_DUP);
SmallVector<NodeProp::DepType> dep_types(input().size(),
NodeProp::DepType::DEV_COMP_ORDER);
dep_types[0] = NodeProp::DepType::DEV_VALUE;
using IT = cg::static_infer::InferType;
auto host_value_avail = [&]() -> bool {
auto inp = input(0);
auto it = owner_graph()->static_infer_manager().get_infer_type(inp).value;
return it & (IT::CONST | IT::RT_STATIC | IT::MISSING_INP);
};
m_use_host_value = m_param.prefer_host_value && host_value_avail();
dep_types[0] = m_use_host_value ? NodeProp::DepType::HOST_VALUE : NodeProp::DepType::DEV_VALUE;
prop->reset_dep_type(input(), dep_types);
return prop;
}
void OutputCallback::scn_do_execute() {
m_param.callback(input(0)->dev_tensor());
if (m_use_host_value) {
m_param.callback(owner_graph()->static_infer_manager().infer_value(input(0)));
} else {
m_param.callback(input(0)->dev_tensor());
}
}
cg::OperatorNodeBase* OutputCallback::shallow_copy(
......
......@@ -60,7 +60,8 @@ public:
using callback_t = thin_function<void(DeviceTensorND)>;
struct Param {
callback_t callback;
bool borrow = false;
bool borrow = false; // do not obtain shared ownership on DeviceTensorND
bool prefer_host_value = false; // use host value when possible
};
OutputCallback(Param param,
const VarNodeArray& inputs,
......@@ -81,6 +82,7 @@ protected:
NodeProp* do_make_node_prop() const override;
private:
Param m_param;
mutable bool m_use_host_value;
};
MGB_DEFINE_OPR_CLASS(NopCallback, cg::OperatorNodeBase) // {
......
......@@ -13,6 +13,7 @@
#include "megbrain/opr/io.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/test/helper.h"
using namespace mgb;
......@@ -50,6 +51,27 @@ TEST(TestOprUtility, OutputCallback) {
MGB_ASSERT_TENSOR_EQ(hy, *hx);
}
TEST(TestOprUtility, OutputCallbackPreferHost) {
HostTensorGenerator<> gen;
auto hx = gen({2, 3});
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, hx);
x = opr::GetVarShape::make(x);
HostTensorND hy;
auto callback = [&hy](DeviceTensorND dv) {hy.copy_from(dv);};
opr::OutputCallback::Param param{callback};
param.prefer_host_value = true;
auto dummy = opr::OutputCallback::make(param, x);
auto y = opr::VirtualDep::make({x, dummy});
ComputingGraph::OutputSpec outspec{{y, [](DeviceTensorND&){}}};
auto func = graph->compile(outspec);
func->execute();
ASSERT_TRUE(hy.comp_node() == CompNode::default_cpu());
ASSERT_EQ(hy.ptr<int>()[0], 2);
ASSERT_EQ(hy.ptr<int>()[1], 3);
}
TEST(TestOprUtility, NopCallback) {
HostTensorGenerator<> gen;
auto hx = gen({2, 3});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册