From fe1680b37854e02744d7b7bc34424298b56be993 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Fri, 22 Jan 2021 12:09:05 +0800
Subject: [PATCH] fix(imperative/ops): improve infer_output_attrs for broadcast

GitOrigin-RevId: 6b7ed5576947a7614b7ad00bbfcbc1d5e520a33a
---
 imperative/src/impl/ops/broadcast.cpp | 46 +++++++++++++--------------
 1 file changed, 23 insertions(+), 23 deletions(-)

diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp
index 4602ecd2..3ea2e870 100644
--- a/imperative/src/impl/ops/broadcast.cpp
+++ b/imperative/src/impl/ops/broadcast.cpp
@@ -58,10 +58,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     auto&& src = inputs[0];
     auto&& tshp = inputs[1];
 
-    TensorLayout out_layout = src.layout;
+    TensorShape out_shape;
     if (tshp.layout.ndim == 0 || tshp.value.empty()) {
-        out_layout.ndim = 0;
-        return {{{out_layout, src.comp_node}}, false};
+        out_shape.ndim = 0;
+        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
     }
     mgb_assert(
         tshp.layout.ndim == 1,
@@ -69,17 +69,17 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         tshp.layout.ndim);
 
     size_t target_ndim = tshp.layout.shape[0];
-    out_layout.ndim = target_ndim;
+    out_shape.ndim = target_ndim;
     auto* ptr = tshp.value.ptr<dt_int32>();
     for (size_t i = 0; i < target_ndim; ++i) {
-        out_layout.shape[i] = ptr[i];
+        out_shape[i] = ptr[i];
     }
-    mgb_assert(valid_broadcast(src.layout, out_layout),
+    mgb_assert(valid_broadcast(src.layout, out_shape),
         "the input shape %s can not be broadcasted to target shape %s",
-        src.layout.TensorShape::to_string().c_str(),
-        out_layout.TensorShape::to_string().c_str());
+        src.layout.to_string().c_str(),
+        out_shape.to_string().c_str());
 
-    return {{{out_layout, src.comp_node}}, true};
+    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
 }
 
 OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
@@ -108,10 +108,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     auto&& src = inputs[0];
     auto&& tshp = inputs[1];
 
-    TensorLayout out_layout = src.layout;
+    TensorShape out_shape;
    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
-        out_layout.ndim = 0;
-        return {{{out_layout, src.comp_node}}, false};
+        out_shape.ndim = 0;
+        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
     }
     mgb_assert(
         tshp.layout.ndim == 1,
@@ -119,31 +119,31 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         tshp.layout.ndim);
 
     size_t target_ndim = tshp.layout.shape[0];
-    out_layout.ndim = target_ndim;
+    out_shape.ndim = target_ndim;
     auto* ptr = tshp.value.ptr<dt_int32>();
     for (size_t i = 0; i < target_ndim; ++i) {
-        out_layout.shape[i] = ptr[i];
+        out_shape[i] = ptr[i];
     }
     if (src.layout.ndim == 0) {
-        return {{{out_layout, src.comp_node}}, false};
+        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
     }
 
     if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
-        mgb_assert(out_layout.shape[op.axis] == -1);
-        out_layout.shape[op.axis] = 1;
-        mgb_assert(src.layout.total_nr_elems() % out_layout.total_nr_elems() == 0,
+        mgb_assert(out_shape[op.axis] == -1);
+        out_shape[op.axis] = 1;
+        mgb_assert(src.layout.total_nr_elems() % out_shape.total_nr_elems() == 0,
             "can not reshape from %s to %s",
             src.layout.to_string().c_str(),
-            out_layout.to_string().c_str());
-        out_layout.shape[op.axis] = src.layout.total_nr_elems() / out_layout.total_nr_elems();
+            out_shape.to_string().c_str());
+        out_shape[op.axis] = src.layout.total_nr_elems() / out_shape.total_nr_elems();
     } else {
-        mgb_assert(src.layout.total_nr_elems() == out_layout.total_nr_elems(),
+        mgb_assert(src.layout.total_nr_elems() == out_shape.total_nr_elems(),
            "can not reshape from %s to %s",
            src.layout.to_string().c_str(),
-            out_layout.to_string().c_str());
+            out_shape.to_string().c_str());
     }
 
-    return {{{out_layout, src.comp_node}}, true};
+    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
 }
 
 OP_TRAIT_REG(Reshape, Reshape)
-- 
GitLab
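
A note on why the change matters: the old code copied src.layout wholesale
and then overwrote only ndim and the shape entries, so the inferred output
descriptor kept whatever strides the source layout happened to have.
Building a plain TensorShape and constructing
TensorLayout(out_shape, src.layout.dtype) instead yields a layout whose
strides are freshly initialized as contiguous for the target shape. The
self-contained C++ sketch below illustrates that difference with
hypothetical, minimal stand-ins for the MegDNN TensorShape/TensorLayout
classes; MAX_NDIM, the constructor body, and main() are invented for this
illustration and are not MegEngine code.

#include <cstddef>
#include <cstdio>

// Hypothetical minimal stand-ins for megdnn::TensorShape/TensorLayout;
// only the fields needed for this illustration are modeled.
constexpr size_t MAX_NDIM = 7;

struct TensorShape {
    size_t shape[MAX_NDIM] = {0};
    size_t ndim = 0;
    size_t& operator[](size_t i) { return shape[i]; }
};

struct TensorLayout : TensorShape {
    ptrdiff_t stride[MAX_NDIM] = {0};
    TensorLayout() = default;
    // Constructing from a shape recomputes contiguous strides -- the
    // behavior the patch relies on via TensorLayout(out_shape, dtype).
    explicit TensorLayout(const TensorShape& s) {
        ndim = s.ndim;
        ptrdiff_t acc = 1;
        for (size_t i = ndim; i-- > 0;) {
            shape[i] = s.shape[i];
            stride[i] = acc;
            acc *= static_cast<ptrdiff_t>(s.shape[i]);
        }
    }
};

int main() {
    // A non-contiguous 2x3 source layout, as a transpose would produce.
    TensorLayout src;
    src.ndim = 2;
    src.shape[0] = 2; src.shape[1] = 3;
    src.stride[0] = 1; src.stride[1] = 2;

    // Old pattern: copy the whole layout, overwrite the shape only.
    // The stale strides (1, 2) survive and do not describe a dense 4x3.
    TensorLayout stale = src;
    stale.shape[0] = 4; stale.shape[1] = 3;

    // New pattern: fill a TensorShape, then construct a fresh layout.
    TensorShape out_shape;
    out_shape.ndim = 2;
    out_shape[0] = 4; out_shape[1] = 3;
    TensorLayout fresh(out_shape);

    std::printf("stale strides: %td %td\n", stale.stride[0], stale.stride[1]);
    std::printf("fresh strides: %td %td\n", fresh.stride[0], fresh.stride[1]);
    return 0;
}

With the copy-then-overwrite pattern the 4x3 output reports the stale
strides (1, 2) inherited from the transposed source, while the constructed
layout reports the contiguous strides (3, 1) -- presumably the reason both
the Broadcast and Reshape paths now build the output layout from scratch.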