提交 5bf31163 编写于 作者: M Megvii Engine Team 提交者: “wenjuan”

fix(mge): fix infer output attrs fallible

GitOrigin-RevId: ea18c7f753aa1b48fd7ff66f7035c70b4745164d
上级 94960ecf
...@@ -1509,7 +1509,7 @@ def sync_batch_norm( ...@@ -1509,7 +1509,7 @@ def sync_batch_norm(
""" """
_eps_mode = eps_mode.lower() _eps_mode = eps_mode.lower()
assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode) assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode)
if _eps_mode == "additive" and not (is_distributed() or training): if _eps_mode == "additive" and not (is_distributed() and training):
return batch_norm( return batch_norm(
inp, inp,
running_mean, running_mean,
......
...@@ -717,7 +717,6 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { ...@@ -717,7 +717,6 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) { if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) {
ptr->to_contiguous_inplace(); ptr->to_contiguous_inplace();
} }
dest->desc.layout = ptr->layout();
dest->desc.comp_node = ptr->comp_node(); dest->desc.comp_node = ptr->comp_node();
dest->memory = ptr->blob()->size(); dest->memory = ptr->blob()->size();
dest->ptr = std::move(ptr); dest->ptr = std::move(ptr);
......
...@@ -205,6 +205,12 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( ...@@ -205,6 +205,12 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
size_t size = inputs.size(); size_t size = inputs.size();
SmallVector<LogicalTensorDesc> dests(size); SmallVector<LogicalTensorDesc> dests(size);
for (size_t i = 0; i < size; i++) {
if (inputs[i].layout.ndim == 0) {
return {{{TensorLayout(inputs[0].layout.dtype), inputs[0].comp_node}},
false};
}
}
if (size > 1) { if (size > 1) {
auto [output_descs, validated] = auto [output_descs, validated] =
proxy_graph_detail::infer_output_attrs_fallible(def, inputs); proxy_graph_detail::infer_output_attrs_fallible(def, inputs);
......
...@@ -115,6 +115,9 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( ...@@ -115,6 +115,9 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
TensorShapeArray src(inputs.size()); TensorShapeArray src(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
src[i] = inputs[i].layout; src[i] = inputs[i].layout;
if (!src[i].ndim) {
return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
}
} }
megdnn::Elemwise::deduce_shape(src, shp); megdnn::Elemwise::deduce_shape(src, shp);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册