提交 22504523 编写于 作者: M Megvii Engine Team

perf(imperative): improve shape inference

GitOrigin-RevId: 98b4d7e9aff5a992306ed6d581390718a0a24caa
上级 df3474ca
......@@ -178,7 +178,11 @@ TensorInfo* ChannelImpl::put_impl(
auto _ = StackManager::Guard{"Put", &state.stack_manager};
auto info = alloc();
MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
constexpr int size_threshold = TensorShape::MAX_NDIM;
init(info, {data.layout(), data.comp_node()});
if ((!hvalue.empty()) && info->desc.layout.total_nr_elems() <= size_threshold) {
info->desc.value = hvalue.proxy_to_default_cpu();
}
info->ptr = Tensor::make(data, hvalue);
MGB_RECORD_EVENT(
TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node,
......
......@@ -58,10 +58,24 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return proxy_graph_detail::apply_on_physical_tensor(def, inputs);
}
/*!
 * \brief Infer the output descriptors for this op, refining the result of the
 *      generic proxy-graph inference when a second input carries a known value.
 *
 * The call is first delegated to proxy_graph_detail::infer_output_attrs_fallible.
 * If exactly two inputs are given, the inferred layout of output 0 still has
 * ndim == 0, and inputs[1] has a non-empty value, that value is converted into
 * the layout's shape via cg::copy_tensor_value_to_shape and contiguous strides
 * are initialised on it.
 * NOTE(review): inputs[1] is presumably the target-shape tensor of Reduce --
 * confirm against the op definition.
 *
 * \param def the operator definition, forwarded to the proxy-graph inference
 * \param inputs logical descriptors of the inputs
 * \return the (possibly refined) output descriptors together with the
 *      `validated` flag exactly as reported by the proxy-graph inference
 */
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto [output_descs, validated] =
            proxy_graph_detail::infer_output_attrs_fallible(def, inputs);
    // Short-circuit order matters: output_descs[0] and inputs[1] are only
    // touched once inputs.size() == 2 has been established.
    if (inputs.size() == 2 && !output_descs[0].layout.ndim &&
        !inputs[1].value.empty()) {
        auto& layout = output_descs[0].layout;
        cg::copy_tensor_value_to_shape(layout, inputs[1].value);
        layout.init_contiguous_stride();
    }
    // A structured binding inside a braced-init-list is an ordinary lvalue and
    // would be copied into the returned tuple; move it explicitly instead.
    return {std::move(output_descs), validated};
}
// Register the op trait for Reduce (associated with opr::Reduce), wiring in the
// handlers of the same names referenced below: construction from an op node,
// application on var nodes, application on physical tensors, and the fallible
// output-attribute inference.
// NOTE(review): the trailing .fallback() presumably fills any handler not set
// explicitly with a default implementation -- confirm against OP_TRAIT_REG.
OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.fallback();
} // namespace reduce
} // namespace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册