Commit 150a6a61, authored by: Megvii Engine Team

perf(dispatch/trace): remove unnecessary h2d for constant

GitOrigin-RevId: d00de3fc1fa3f67485b3b147652e45af82b01c7a
Parent commit: 81d8c73a
......@@ -1399,12 +1399,12 @@ void init_tensor(py::module m) {
std::function<bool(py::object, py::object)> array_comparator;
bool compare_value(ValueRef lhs, ValueRef rhs) {
auto lvalue = lhs.numpy();
auto rvalue = rhs.numpy();
auto lvalue = lhs.cast_ref<HostValue>();
auto rvalue = rhs.cast_ref<HostValue>();
if (lvalue->shape() != rvalue->shape()) {
return false;
}
if (lvalue->shape().is_scalar()) {
if (lvalue->shape().total_nr_elems() == 1) {
return lvalue->item() == rvalue->item();
}
HostTensorND lnd = lvalue->as_nd(true);
......
......@@ -50,9 +50,10 @@ ValueRefList LazyEvalTransformation::apply_transformation(
}
if (require_link && m_io_link.node()) {
mgb_assert(!input_nodes.empty());
input_nodes[0] =
opr::VirtualDep::make({SymbolVar(input_nodes[0]), m_io_link})
.node();
auto comp_node = m_io_link.node()->comp_node();
input_nodes[0] = opr::VirtualDep::make(
{SymbolVar(input_nodes[0]), m_io_link}, comp_node)
.node();
}
VarNodeArray output_nodes = OpDef::apply_on_var_node(op_val->op(), input_nodes);
if (require_link) {
......
......@@ -196,10 +196,11 @@ ValueRefList TracingTransformation::apply_transformation(
return outputs;
}
bool is_const = create_tensor->kind() == CreateTensor::Const;
bool as_const = is_const || m_capture_as_const;
auto wrapped_input = record_var(
outputs[0], is_const || m_capture_as_const,
is_const ? VarKind::Constant : VarKind::External);
auto wrapped_output = record_var(outputs[0], false, VarKind::Internal);
outputs[0], as_const, is_const ? VarKind::Constant : VarKind::External);
// bound data to outputs too to reduce runtime overhead for shape/value infer
auto wrapped_output = record_var(outputs[0], as_const, VarKind::Internal);
auto input_id = wrapped_input->id();
auto output_id = wrapped_output->id();
m_seq.push_back({{}, {input_id}, {output_id}});
......@@ -311,6 +312,18 @@ void CompiledTransformation::compile() {
auto make_output = [&](TraceResult::VarInfo* var_info, SymbolVar node) {
VarAccessor accessor;
accessor.node = node.node();
if (auto bound_data = var_info->bound_data) {
accessor.shape_getter = [bound_data]() -> TensorShape {
return bound_data.shape()->as_tensor_shape();
};
accessor.data_getter = [bound_data]() -> DeviceTensorND {
return bound_data.dev_tensor()->as_nd();
};
accessor.value_getter = [bound_data]() -> HostTensorND {
return bound_data.numpy()->as_nd();
};
return accessor;
}
if (var_info->data_required) {
// reduce d2h when data is available
// FIXME: compile should not change var_info in-place
......@@ -410,16 +423,28 @@ void CompiledTransformation::compile() {
"internal node should be valid when used as input");
}
}
input_vars.push_back(var_accessors[input].node);
auto& node = var_accessors[input].node;
if (input_vars.empty() && require_link && mm_io_link.node()) {
/*mgb_assert(
!input_vars.empty(),
"io-mm operator should have at least one input");*/
auto comp_node = mm_io_link.node()->comp_node();
// auto comp_node = input_vars[0]->comp_node();
node = opr::VirtualDep::make({SymbolVar(node), mm_io_link}, comp_node)
.node();
}
input_vars.push_back(node);
}
if (require_link && mm_io_link.node()) {
/*if (require_link && mm_io_link.node()) {
mgb_assert(
!input_vars.empty(),
"io-mm operator should have at least one input");
input_vars[0] =
opr::VirtualDep::make({SymbolVar(input_vars[0]), mm_io_link})
.node();
}
auto comp_node = mm_io_link.node()->comp_node();
// auto comp_node = input_vars[0]->comp_node();
input_vars[0] = opr::VirtualDep::make(
{SymbolVar(input_vars[0]), mm_io_link}, comp_node)
.node();
}*/
VarNodeArray output_vars;
if (item.op) {
output_vars = OpDef::apply_on_var_node(*item.op, input_vars);
......@@ -479,6 +504,12 @@ void CompiledTransformation::recompile() {
}
// Assert that two tensor values compare equal under the user-registered
// comparator (m_value_comparator), raising a trace error otherwise.
// Each operand is first normalized to a host-side value: if a side is not
// already a HostValue, numpy() materializes one — presumably forcing a
// device-to-host copy (TODO confirm; the commit message suggests this path
// is what makes passing HostValue directly cheaper for constants).
void CompiledTransformation::assert_tensor_equal(ValueRef lhs, ValueRef rhs) {
if (!lhs.is<HostValue>()) {
lhs = lhs.numpy();
}
if (!rhs.is<HostValue>()) {
rhs = rhs.numpy();
}
// trace_assert aborts tracing with the given message on mismatch.
trace_assert(m_value_comparator(lhs, rhs), "tensors not equals");
}
......@@ -507,6 +538,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
break;
}
case VarKind::Constant: {
// expect host value here
mgb_assert(var.bound_data, "const var without data bound");
assert_tensor_equal(var.bound_data, value);
break;
......@@ -611,7 +643,17 @@ ValueRefList CompiledTransformation::apply_create_tensor(
trace_assert(item.op == nullptr, "operator mismatch");
auto input_id = item.inputs[0];
auto output_id = item.outputs[0];
auto tensor = imperative::apply(create_tensor, inputs)[0];
ValueRef tensor;
if (create_tensor.kind() == CreateTensor::Const) {
auto args = create_tensor.parse(inputs);
if (args.host) {
// performance issue
tensor = HostValue::make(*args.host);
}
}
if (!tensor) {
tensor = imperative::apply(create_tensor, inputs)[0];
}
trace_input(input_id, tensor);
return {trace_output(output_id)};
}
......
......@@ -103,7 +103,8 @@ public:
CompNode device() const { return m_storage.comp_node(); }
const HostTensorStorage& storage() const { return m_storage; }
// Return the tensor's (single) element as a DTypeScalar read from the start
// of host storage.
// NOTE(review): the two mgb_assert lines below are the pre-change and
// post-change lines of a rendered diff shown without +/- markers — the real
// file contains only the second one. As plain source this block would
// assert twice. The replacement check (total_nr_elems() nonzero) only
// rejects empty tensors and does not guarantee a single element, which is
// exactly what the in-code FIXME flags.
DTypeScalar item() const {
mgb_assert(m_shape.is_scalar());
// FIXME: check scalar
mgb_assert(m_shape.total_nr_elems());
return DTypeScalar::make_from_raw(m_dtype, m_storage.ptr());
}
......
......@@ -47,7 +47,8 @@ struct TraceResult {
DTypeValue::ref_t dtype;
CompNodeValue::ref_t device;
// if exists, assert equal when meet
// if exists, for input: assert equal
// for output: get_data/shape/value
ValueRef bound_data;
std::string mark;
std::string name;
......
......@@ -49,6 +49,7 @@ struct ValueShape {
size_t total_nr_elems() const {
size_t prod = 1;
mgb_assert(ndim >= 0 && ndim < 8);
for (int i = 0; i < ndim; ++i) {
prod *= shape[i];
}
......@@ -103,4 +104,4 @@ static_assert(sizeof(size_t) >= sizeof(int));
static_assert(TensorShape::MAX_NDIM == 7);
static_assert(sizeof(ValueShape) <= sizeof(size_t) * 8);
} // namespace mgb::imperative
\ No newline at end of file
} // namespace mgb::imperative
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register