提交 81d8c73a 编写于 作者: M Megvii Engine Team

perf(dispatch/trace): several tricks to speed up trace

GitOrigin-RevId: 2bdd70cde2d19f43055218804abd65f4e5a54b89
上级 4fa61620
...@@ -47,7 +47,7 @@ VarNodeArray TraceResult::dump( ...@@ -47,7 +47,7 @@ VarNodeArray TraceResult::dump(
auto& node = nodes[input]; auto& node = nodes[input];
// TODO: cambricon CompNode // TODO: cambricon CompNode
auto host = std::make_shared<HostTensorND>( auto host = std::make_shared<HostTensorND>(
CompNode::load("xpux"), shape, var.dtype); CompNode::load("xpux"), shape, *var.dtype);
OperatorNodeConfig config; OperatorNodeConfig config;
// if prefer_input_names, prefer names from dump args // if prefer_input_names, prefer names from dump args
// else prefer names got from trace procedure // else prefer names got from trace procedure
...@@ -211,7 +211,6 @@ ValueRefList TracingTransformation::apply_transformation( ...@@ -211,7 +211,6 @@ ValueRefList TracingTransformation::apply_transformation(
auto& var_info = m_vars[tracing_value->id()]; auto& var_info = m_vars[tracing_value->id()];
switch (get_attr->attr()) { switch (get_attr->attr()) {
case GetAttr::Shape: case GetAttr::Shape:
// TODO: reduce h2d when data or value is available
var_info.shape_required = true; var_info.shape_required = true;
break; break;
case GetAttr::Data: case GetAttr::Data:
...@@ -301,8 +300,8 @@ void CompiledTransformation::compile() { ...@@ -301,8 +300,8 @@ void CompiledTransformation::compile() {
auto box = make_box<DeviceTensorND>(); auto box = make_box<DeviceTensorND>();
// TODO: attach ref count, release early // TODO: attach ref count, release early
auto outputs = opr::InputCallback::make( auto outputs = opr::InputCallback::make(
*m_graph, [box] { return box->take_value(); }, var_info->device, *m_graph, [box] { return box->take_value(); }, *var_info->device,
var_info->dtype, var_info->shape, io_links, m_input_shape_static); *var_info->dtype, var_info->shape, io_links, m_input_shape_static);
// attach input_callback to io_links // attach input_callback to io_links
accessor.node = outputs[0].node(); accessor.node = outputs[0].node();
io_links = {outputs[1]}; io_links = {outputs[1]};
...@@ -312,6 +311,11 @@ void CompiledTransformation::compile() { ...@@ -312,6 +311,11 @@ void CompiledTransformation::compile() {
auto make_output = [&](TraceResult::VarInfo* var_info, SymbolVar node) { auto make_output = [&](TraceResult::VarInfo* var_info, SymbolVar node) {
VarAccessor accessor; VarAccessor accessor;
accessor.node = node.node(); accessor.node = node.node();
if (var_info->data_required) {
// reduce d2h when data is available
// FIXME: compile should not change var_info in-place
var_info->shape_required = false;
}
if (var_info->shape_required) { if (var_info->shape_required) {
// TODO: use static infer manager for some vars? // TODO: use static infer manager for some vars?
auto box = make_box<TensorShape>(); auto box = make_box<TensorShape>();
...@@ -334,6 +338,12 @@ void CompiledTransformation::compile() { ...@@ -334,6 +338,12 @@ void CompiledTransformation::compile() {
accessor.data_getter = [box]() -> DeviceTensorND { accessor.data_getter = [box]() -> DeviceTensorND {
return box->get_value(); return box->get_value();
}; };
if (!accessor.shape_getter) {
// also implement shape_getter
accessor.shape_getter = [box]() -> TensorShape {
return box->get_value().shape();
};
}
} }
if (var_info->value_required) { if (var_info->value_required) {
struct ValueWithEvent { struct ValueWithEvent {
...@@ -341,7 +351,7 @@ void CompiledTransformation::compile() { ...@@ -341,7 +351,7 @@ void CompiledTransformation::compile() {
CompNode::Event* event = nullptr; CompNode::Event* event = nullptr;
}; };
auto box = make_box<ValueWithEvent>(); auto box = make_box<ValueWithEvent>();
auto event = EventPool::without_timer().alloc_shared(var_info->device); auto event = EventPool::without_timer().alloc_shared(*var_info->device);
auto callback = [box, event](DeviceTensorND data) { auto callback = [box, event](DeviceTensorND data) {
HostTensorND host_val; HostTensorND host_val;
host_val.copy_from(data); host_val.copy_from(data);
...@@ -355,7 +365,7 @@ void CompiledTransformation::compile() { ...@@ -355,7 +365,7 @@ void CompiledTransformation::compile() {
}; };
SymbolVarArray inputs = io_links; SymbolVarArray inputs = io_links;
inputs.insert(inputs.begin(), node); inputs.insert(inputs.begin(), node);
auto output = opr::OutputCallback::make({callback, false, true}, inputs); auto output = opr::OutputCallback::make({callback, true, true}, inputs);
io_links = {output}; io_links = {output};
accessor.value_getter = [box]() -> HostTensorND { accessor.value_getter = [box]() -> HostTensorND {
auto&& [value, event] = box->get_value(); auto&& [value, event] = box->get_value();
...@@ -486,11 +496,12 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { ...@@ -486,11 +496,12 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
DType dtype = *value.dtype(); DType dtype = *value.dtype();
CompNode device = *value.device(); CompNode device = *value.device();
trace_assert( trace_assert(
var.dtype == dtype, "dtype mismatch: %s vs %s", *var.dtype == dtype, "dtype mismatch: %s vs %s",
var.dtype.name(), dtype.name()); var.dtype->name(), dtype.name());
trace_assert( trace_assert(
var.device == device, "comp_node mismatch: %s vs %s", *var.device == device, "comp_node mismatch: %s vs %s",
var.device.to_string().c_str(), device.to_string().c_str()); var.device->to_string().c_str(),
device.to_string().c_str());
} }
var_accessor.data_setter(value.dev_tensor()->as_nd()); var_accessor.data_setter(value.dev_tensor()->as_nd());
break; break;
...@@ -535,17 +546,11 @@ ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const { ...@@ -535,17 +546,11 @@ ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const {
} }
DTypeValue::ref_t CompiledTransformation::TracedInfo::dtype() const { DTypeValue::ref_t CompiledTransformation::TracedInfo::dtype() const {
if (!m_dtype) { return m_var->dtype;
m_dtype = DTypeValue::make(m_var->dtype);
}
return m_dtype;
} }
CompNodeValue::ref_t CompiledTransformation::TracedInfo::comp_node() const { CompNodeValue::ref_t CompiledTransformation::TracedInfo::comp_node() const {
if (!m_comp_node) { return m_var->device;
m_comp_node = CompNodeValue::make(m_var->device);
}
return m_comp_node;
} }
auto CompiledTransformation::TracedInfo::accessor() const -> const VarAccessor& { auto CompiledTransformation::TracedInfo::accessor() const -> const VarAccessor& {
return *m_accessor; return *m_accessor;
......
...@@ -44,8 +44,8 @@ struct TraceResult { ...@@ -44,8 +44,8 @@ struct TraceResult {
}; };
size_t id; size_t id;
DType dtype; DTypeValue::ref_t dtype;
CompNode device; CompNodeValue::ref_t device;
// if exists, assert equal when meet // if exists, assert equal when meet
ValueRef bound_data; ValueRef bound_data;
...@@ -162,7 +162,7 @@ public: ...@@ -162,7 +162,7 @@ public:
TypedValueRef<TracingValue> record_var(ValueRef value, bool capture, VarKind kind) { TypedValueRef<TracingValue> record_var(ValueRef value, bool capture, VarKind kind) {
size_t id = m_vars.size(); size_t id = m_vars.size();
auto wrapped_value = TracingValue::make(value, id); auto wrapped_value = TracingValue::make(value, id);
m_vars.push_back({id, *value.dtype(), *value.device()}); m_vars.push_back({id, value.dtype(), value.device()});
auto& var = m_vars.back(); auto& var = m_vars.back();
if (capture) { if (capture) {
var.bound_data = value; var.bound_data = value;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册