From 248d8bf0dcfe98811c58c44b7cc58634ad28dbaf Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 15 Dec 2020 11:58:24 +0800
Subject: [PATCH] feat(imperative/ops): improve infer attrs validate function

GitOrigin-RevId: 6fab3b140220709c6edf92bd5a59105c96c2320a
---
 imperative/src/impl/interpreter_impl.cpp   |  6 +++---
 imperative/src/impl/ops/backward_graph.cpp |  2 +-
 imperative/src/impl/ops/batch_norm.cpp     | 11 ++++-------
 imperative/src/impl/ops/broadcast.cpp      |  8 ++++----
 imperative/src/impl/ops/cond_take.cpp      |  2 +-
 imperative/src/impl/ops/elemwise.cpp       |  2 +-
 imperative/src/impl/ops/tensor_manip.cpp   |  2 +-
 7 files changed, 15 insertions(+), 18 deletions(-)

diff --git a/imperative/src/impl/interpreter_impl.cpp b/imperative/src/impl/interpreter_impl.cpp
index 54ebd96e9..24e47afd4 100644
--- a/imperative/src/impl/interpreter_impl.cpp
+++ b/imperative/src/impl/interpreter_impl.cpp
@@ -176,7 +176,7 @@ TensorShape ChannelImpl::get_shape(void* handle) {
     m_buffer.enqueue(Flush{info});
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
-        return bool(info->ptr);
+        return static_cast<bool>(info->ptr);
     });
     m_waitee = nullptr;
     TensorShape ret = info->ptr->layout();
@@ -212,7 +212,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) {
     m_buffer.enqueue(Flush{info});
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
-        return bool(info->ptr);
+        return static_cast<bool>(info->ptr);
     });
     m_waitee = nullptr;
     return info->ptr->dev_tensor();
@@ -232,7 +232,7 @@ void ChannelImpl::close() {
 }
 
 void ChannelImpl::config_async_level(int level) {
-    mgb_assert(level <= 2 and level >= 0, "async_level should be 0, 1 or 2");
+    mgb_assert(level <= 2 && level >= 0, "async_level should be 0, 1 or 2");
     m_async_level = level;
 }
 
diff --git a/imperative/src/impl/ops/backward_graph.cpp b/imperative/src/impl/ops/backward_graph.cpp
index c20432ac9..6b729f43c 100644
--- a/imperative/src/impl/ops/backward_graph.cpp
+++ b/imperative/src/impl/ops/backward_graph.cpp
@@ -49,7 +49,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::i
             expr_input_descs.push_back(node2attr.at(inp));
         }
 
-        auto[expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible(
+        auto [expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible(
             *expr_op, expr_input_descs);
         validated = validated && expr_validated;
 
diff --git a/imperative/src/impl/ops/batch_norm.cpp b/imperative/src/impl/ops/batch_norm.cpp
index 2ca8c7602..e1db40ae0 100644
--- a/imperative/src/impl/ops/batch_norm.cpp
+++ b/imperative/src/impl/ops/batch_norm.cpp
@@ -54,16 +54,13 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     SmallVector<LogicalTensorDesc> out_shapes(nr_out);
     auto&& i0 = inputs[0];
     auto&& i1 = inputs[1];
-    size_t i = 0;
-    if (!need_stat) {
-        out_shapes[0] = out_shapes[1] = {TensorLayout({0}, i0.layout.dtype, i0.layout.format), i0.comp_node};
-        i = 2;
-    }
-    for (; i < nr_out-1; ++ i) {
+    // [running_mean, running_var,] save_mean, save_var
+    for (size_t i = 0; i < nr_out-1; ++ i) {
         out_shapes[i] = {i1.layout, i1.comp_node};
     }
+    // output tensor
     out_shapes[nr_out-1] = {i0.layout, i0.comp_node};
-    return {out_shapes, true};
+    return {out_shapes, out_shapes[nr_out-1].layout.ndim != 0};
 }
 
 OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp
index 16984c4d9..873a9693a 100644
--- a/imperative/src/impl/ops/broadcast.cpp
+++ b/imperative/src/impl/ops/broadcast.cpp
@@ -61,17 +61,17 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     TensorLayout out_layout = src.layout;
     if (tshp.layout.ndim == 0 || tshp.value.empty()) {
         out_layout.ndim = 0;
-        return {{{out_layout, src.comp_node}}, true};
+        return {{{out_layout, src.comp_node}}, false};
     }
     mgb_assert(
-        tshp.layout.ndim == 1,
-        "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
+        tshp.layout.ndim == 1,
+        "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
         tshp.layout.ndim);
     size_t target_ndim = tshp.layout.shape[0];
     out_layout.ndim = target_ndim;
     auto* ptr = tshp.value.ptr<dt_int32>();
-    for(size_t i=0; i<target_ndim; ++i) {
+    for (size_t i = 0; i < target_ndim; ++i) {
         out_layout.shape[i] = ptr[i];
     }
diff --git a/imperative/src/impl/ops/cond_take.cpp b/imperative/src/impl/ops/cond_take.cpp
--- a/imperative/src/impl/ops/cond_take.cpp
+++ b/imperative/src/impl/ops/cond_take.cpp
@@ ... @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     return {{
         {TensorLayout(inputs[0].layout.dtype), cn},
         {TensorLayout(dtype::Int32()), cn}
-    }, true};
+    }, false};
 }
 
 OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp
index e32c10d84..2a30f544b 100644
--- a/imperative/src/impl/ops/elemwise.cpp
+++ b/imperative/src/impl/ops/elemwise.cpp
@@ -60,7 +60,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         TensorLayout out_layout;
         out_layout.ndim = 0;
         out_layout.dtype = out_dt;
-        return {{{out_layout, out_cn}}, true};
+        return {{{out_layout, out_cn}}, false};
     }
 }
 
diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp
index e4a8344f2..a52189617 100644
--- a/imperative/src/impl/ops/tensor_manip.cpp
+++ b/imperative/src/impl/ops/tensor_manip.cpp
@@ -59,7 +59,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
     auto&& desc = inputs[0];
     if (!desc.layout.ndim) {
-        return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, true};
+        return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
     }
     DeviceTensorND value;
     if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
-- 
GitLab
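
Note (not part of the patch): the convention the hunks above converge on is that each `infer_output_attrs_fallible` implementation returns its output descriptors together with a `validated` flag, and the flag is now false whenever an output layout could not be fully inferred yet (ndim == 0) rather than being unconditionally true. The sketch below illustrates that convention in isolation; `Layout`, `TensorDesc`, and `infer_output_attrs_fallible_demo` are simplified stand-in names, not MegEngine's real types or API.

#include <cstddef>
#include <iostream>
#include <tuple>
#include <vector>

// Simplified stand-ins for MegEngine's TensorLayout / LogicalTensorDesc.
struct Layout {
    std::size_t ndim = 0;  // ndim == 0 means "shape not known yet"
};
struct TensorDesc {
    Layout layout;
};

// Returns inferred output descriptors plus a "validated" flag. The flag is
// true only when the output layout is fully known, mirroring the checks the
// patch adds (false on early-exit paths, `layout.ndim != 0` otherwise).
std::tuple<std::vector<TensorDesc>, bool> infer_output_attrs_fallible_demo(
        const std::vector<TensorDesc>& inputs) {
    std::vector<TensorDesc> outputs(1);
    if (inputs.empty() || inputs[0].layout.ndim == 0) {
        // Shape not available yet: still return a placeholder descriptor,
        // but tell the caller the attributes are not validated.
        return {outputs, false};
    }
    outputs[0].layout = inputs[0].layout;  // toy rule: shape passes through
    return {outputs, outputs[0].layout.ndim != 0};
}

int main() {
    auto [descs, validated] = infer_output_attrs_fallible_demo({TensorDesc{Layout{3}}});
    std::cout << "ndim=" << descs[0].layout.ndim
              << " validated=" << validated << '\n';  // prints: ndim=3 validated=1
}

A caller can then use the flag to decide whether shape-dependent work has to wait for actual execution instead of trusting a placeholder layout.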