提交 2d72de8a 编写于 作者: M Megvii Engine Team

fix(mge): fix infer output attrs fallible

GitOrigin-RevId: ea18c7f753aa1b48fd7ff66f7035c70b4745164d
上级 b6ad4572
......@@ -1509,7 +1509,7 @@ def sync_batch_norm(
"""
_eps_mode = eps_mode.lower()
assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode)
if _eps_mode == "additive" and not (is_distributed() or training):
if _eps_mode == "additive" and not (is_distributed() and training):
return batch_norm(
inp,
running_mean,
......
......@@ -717,7 +717,6 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) {
ptr->to_contiguous_inplace();
}
dest->desc.layout = ptr->layout();
dest->desc.comp_node = ptr->comp_node();
dest->memory = ptr->blob()->size();
dest->ptr = std::move(ptr);
......
......@@ -205,6 +205,12 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
size_t size = inputs.size();
SmallVector<LogicalTensorDesc> dests(size);
for (size_t i = 0; i < size; i++) {
if (inputs[i].layout.ndim == 0) {
return {{{TensorLayout(inputs[0].layout.dtype), inputs[0].comp_node}},
false};
}
}
if (size > 1) {
auto [output_descs, validated] =
proxy_graph_detail::infer_output_attrs_fallible(def, inputs);
......
......@@ -115,6 +115,9 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
TensorShapeArray src(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
src[i] = inputs[i].layout;
if (!src[i].ndim) {
return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
}
}
megdnn::Elemwise::deduce_shape(src, shp);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册