提交 81d8c73a 编写于 作者: M Megvii Engine Team

perf(dispatch/trace): serval tricks to speed up trace

GitOrigin-RevId: 2bdd70cde2d19f43055218804abd65f4e5a54b89
上级 4fa61620
......@@ -47,7 +47,7 @@ VarNodeArray TraceResult::dump(
auto& node = nodes[input];
// TODO: cambricon CompNode
auto host = std::make_shared<HostTensorND>(
CompNode::load("xpux"), shape, var.dtype);
CompNode::load("xpux"), shape, *var.dtype);
OperatorNodeConfig config;
// if prefer_input_names, prefer names from dump args
// else prefer names got from trace procedure
......@@ -211,7 +211,6 @@ ValueRefList TracingTransformation::apply_transformation(
auto& var_info = m_vars[tracing_value->id()];
switch (get_attr->attr()) {
case GetAttr::Shape:
// TODO: reduce h2d when data or value is available
var_info.shape_required = true;
break;
case GetAttr::Data:
......@@ -301,8 +300,8 @@ void CompiledTransformation::compile() {
auto box = make_box<DeviceTensorND>();
// TODO: attach ref count, release early
auto outputs = opr::InputCallback::make(
*m_graph, [box] { return box->take_value(); }, var_info->device,
var_info->dtype, var_info->shape, io_links, m_input_shape_static);
*m_graph, [box] { return box->take_value(); }, *var_info->device,
*var_info->dtype, var_info->shape, io_links, m_input_shape_static);
// attach input_callback to io_links
accessor.node = outputs[0].node();
io_links = {outputs[1]};
......@@ -312,6 +311,11 @@ void CompiledTransformation::compile() {
auto make_output = [&](TraceResult::VarInfo* var_info, SymbolVar node) {
VarAccessor accessor;
accessor.node = node.node();
if (var_info->data_required) {
// reduce d2h when data is available
// FIXME: compile should not change var_info in-place
var_info->shape_required = false;
}
if (var_info->shape_required) {
// TODO: use static infer manager for some vars?
auto box = make_box<TensorShape>();
......@@ -334,6 +338,12 @@ void CompiledTransformation::compile() {
accessor.data_getter = [box]() -> DeviceTensorND {
return box->get_value();
};
if (!accessor.shape_getter) {
// also implement shape_getter
accessor.shape_getter = [box]() -> TensorShape {
return box->get_value().shape();
};
}
}
if (var_info->value_required) {
struct ValueWithEvent {
......@@ -341,7 +351,7 @@ void CompiledTransformation::compile() {
CompNode::Event* event = nullptr;
};
auto box = make_box<ValueWithEvent>();
auto event = EventPool::without_timer().alloc_shared(var_info->device);
auto event = EventPool::without_timer().alloc_shared(*var_info->device);
auto callback = [box, event](DeviceTensorND data) {
HostTensorND host_val;
host_val.copy_from(data);
......@@ -355,7 +365,7 @@ void CompiledTransformation::compile() {
};
SymbolVarArray inputs = io_links;
inputs.insert(inputs.begin(), node);
auto output = opr::OutputCallback::make({callback, false, true}, inputs);
auto output = opr::OutputCallback::make({callback, true, true}, inputs);
io_links = {output};
accessor.value_getter = [box]() -> HostTensorND {
auto&& [value, event] = box->get_value();
......@@ -486,11 +496,12 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
DType dtype = *value.dtype();
CompNode device = *value.device();
trace_assert(
var.dtype == dtype, "dtype mismatch: %s vs %s",
var.dtype.name(), dtype.name());
*var.dtype == dtype, "dtype mismatch: %s vs %s",
var.dtype->name(), dtype.name());
trace_assert(
var.device == device, "comp_node mismatch: %s vs %s",
var.device.to_string().c_str(), device.to_string().c_str());
*var.device == device, "comp_node mismatch: %s vs %s",
var.device->to_string().c_str(),
device.to_string().c_str());
}
var_accessor.data_setter(value.dev_tensor()->as_nd());
break;
......@@ -535,17 +546,11 @@ ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const {
}
DTypeValue::ref_t CompiledTransformation::TracedInfo::dtype() const {
if (!m_dtype) {
m_dtype = DTypeValue::make(m_var->dtype);
}
return m_dtype;
return m_var->dtype;
}
CompNodeValue::ref_t CompiledTransformation::TracedInfo::comp_node() const {
if (!m_comp_node) {
m_comp_node = CompNodeValue::make(m_var->device);
}
return m_comp_node;
return m_var->device;
}
auto CompiledTransformation::TracedInfo::accessor() const -> const VarAccessor& {
return *m_accessor;
......
......@@ -44,8 +44,8 @@ struct TraceResult {
};
size_t id;
DType dtype;
CompNode device;
DTypeValue::ref_t dtype;
CompNodeValue::ref_t device;
// if exists, assert equal when meet
ValueRef bound_data;
......@@ -162,7 +162,7 @@ public:
TypedValueRef<TracingValue> record_var(ValueRef value, bool capture, VarKind kind) {
size_t id = m_vars.size();
auto wrapped_value = TracingValue::make(value, id);
m_vars.push_back({id, *value.dtype(), *value.device()});
m_vars.push_back({id, value.dtype(), value.device()});
auto& var = m_vars.back();
if (capture) {
var.bound_data = value;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册