提交 cbb47089 编写于 作者: M Megvii Engine Team

perf(interpreter): add fastpath for GetVarShape

GitOrigin-RevId: d1ac4e7fe38c8b8dc4d5141603ebea545a3c396f
上级 b4581788
......@@ -269,7 +269,20 @@ void ChannelImpl::dispatch_default_cpu(
uint64_t op_id = Profiler::next_id();
OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
if (op->trait()->apply_on_device_tensornd) {
OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
} else {
// proxy to apply_on_physical_tensor
SmallVector<TensorPtr> input_tensors;
for (auto&& input_tensornd : input_tensornds) {
input_tensors.push_back(Tensor::make(
input_tensornd, HostTensorND::make_proxy(input_tensornd)));
}
auto output_tensors = OpDef::apply_on_physical_tensor(*op, input_tensors);
for (size_t i = 0; i < output_tensors.size(); ++i) {
output_tensornds[i].copy_from_fixlayout(output_tensors[i]->dev_tensor());
}
}
SmallVector<TensorInfo*> output_infos;
output_infos.reserve(output_descs.size());
......@@ -357,6 +370,17 @@ SmallVector<Handle> ChannelImpl::apply_op(
std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
auto* input = reinterpret_cast<TensorInfo*>(inputs[0]);
if (op->same_type<GetVarShape>() && input->desc.layout.ndim) {
size_t ndim = input->desc.layout.ndim;
auto& gvs = op->cast_final_safe<GetVarShape>();
if (gvs.axis == MEGDNN_MAX_NDIM) {
HostTensorND shape_tensor{input->desc.comp_node, {ndim}, dtype::Int32()};
DeviceTensorND shape_tensor_device = shape_tensor.proxy_to_default_cpu();
cg::copy_shape_to_tensor_value(shape_tensor_device, input->desc.layout);
return {reinterpret_cast<Handle>(put_impl(shape_tensor, false))};
}
}
return apply_op_impl(std::move(op), inputs);
}
......@@ -621,6 +645,12 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
ptr->dev_tensor().raw_ptr());
// update tensor desc for static infer
if (dest->desc.layout.ndim) {
mgb_assert(
dest->desc.layout.eq_shape(ptr->layout()),
"shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(),
ptr->layout().to_string().c_str());
}
dest->desc.layout = ptr->layout();
dest->desc.comp_node = ptr->comp_node();
dest->memory = ptr->blob()->size();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册