diff --git a/mindspore/lite/src/lite_kernel.cc b/mindspore/lite/src/lite_kernel.cc index fa298575a6282e15dceafe2095d402698d2e1f55..5d0cda8ebf5a4352f2a7e122083298206b55fca1 100644 --- a/mindspore/lite/src/lite_kernel.cc +++ b/mindspore/lite/src/lite_kernel.cc @@ -42,11 +42,16 @@ int LiteKernel::DecOutTensorRefCount() { std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputKernels( const std::vector<kernel::LiteKernel *> &kernels) { std::vector<kernel::LiteKernel *> input_kernels; - for (const auto kernel : kernels) { - for (auto input : kernel->in_kernels()) { + for (const auto &kernel : kernels) { + if (kernel->in_kernels().empty() && !kernel->in_tensors().empty()) { + input_kernels.emplace_back(kernel); + continue; + } + for (const auto &input : kernel->in_kernels()) { auto iter = std::find(kernels.begin(), kernels.end(), input); - if (iter == kernels.end()) { - input_kernels.emplace_back(input); + auto item = std::find(input_kernels.begin(), input_kernels.end(), kernel); + if (iter == kernels.end() && item == input_kernels.end()) { + input_kernels.emplace_back(kernel); } } } @@ -56,11 +61,16 @@ std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphOutputKernels( const std::vector<kernel::LiteKernel *> &kernels) { std::vector<kernel::LiteKernel *> output_kernels; - for (const auto kernel : kernels) { - for (const auto output : kernel->out_kernels()) { + for (const auto &kernel : kernels) { + if (kernel->out_kernels().empty() && !kernel->out_tensors().empty()) { + output_kernels.emplace_back(kernel); + continue; + } + for (const auto &output : kernel->out_kernels()) { auto iter = std::find(kernels.begin(), kernels.end(), output); - if (iter == kernels.end()) { - output_kernels.emplace_back(output); + auto item = std::find(output_kernels.begin(), output_kernels.end(), kernel); + if (iter == kernels.end() && item == output_kernels.end()) { + output_kernels.emplace_back(kernel); } } } diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index
e50e6123b320f4056992884e37fb514abe526986..74b0a090dce58b34d308aa566e08d8a7e9ac6681 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -289,6 +289,7 @@ set(TEST_SRC ${TEST_DIR}/main.cc ${TEST_DIR}/ut/src/runtime/kernel/arm/common/pack_tests.cc ${TEST_DIR}/ut/src/infer_test.cc + ${TEST_DIR}/ut/src/utils_test.cc ) if (SUPPORT_TRAIN) diff --git a/mindspore/lite/test/ut/src/utils_test.cc b/mindspore/lite/test/ut/src/utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a9b80ca94fab5a9d2dd8365e2320968b09acbbb2 --- /dev/null +++ b/mindspore/lite/test/ut/src/utils_test.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <memory> +#include <vector> +#include "mindspore/lite/schema/inner/model_generated.h" +#include "mindspore/lite/include/model.h" +#include "common/common_test.h" +#include "include/lite_session.h" +#include "include/context.h" +#include "include/errorcode.h" +#include "mindspore/core/utils/log_adapter.h" +#include "mindspore/lite/src/lite_kernel.h" + +namespace mindspore { +class UtilsTest : public mindspore::CommonTest { + public: + UtilsTest() {} +}; + +TEST_F(UtilsTest, TestSubgraph) { + auto kernel0 = std::make_shared<kernel::LiteKernel>(); + auto kernel1 = std::make_shared<kernel::LiteKernel>(); + auto kernel2 = std::make_shared<kernel::LiteKernel>(); + + auto tensor0 = std::make_shared<lite::tensor::Tensor>(); + auto tensor1 = std::make_shared<lite::tensor::Tensor>(); + auto tensor2 = std::make_shared<lite::tensor::Tensor>(); + auto tensor3 = std::make_shared<lite::tensor::Tensor>(); + auto tensor4 = std::make_shared<lite::tensor::Tensor>(); + + kernel0->AddOutKernel(kernel1.get()); + kernel1->AddInKernel(kernel0.get()); + kernel1->AddOutKernel(kernel2.get()); + kernel2->AddInKernel(kernel1.get()); + + kernel0->set_in_tensors({tensor0.get(), tensor1.get()}); + kernel0->set_out_tensors({tensor2.get()}); + kernel1->set_in_tensors({tensor2.get()}); + kernel1->set_out_tensors({tensor3.get()}); + kernel2->set_in_tensors({tensor3.get()}); + kernel2->set_out_tensors({tensor4.get()}); + + std::vector<kernel::LiteKernel *> kernels = {kernel0.get(), kernel1.get(), kernel2.get()}; + + auto input_kernels = kernel::LiteKernelUtil::SubgraphInputKernels(kernels); + ASSERT_EQ(input_kernels.size(), 1); + auto output_kernels = kernel::LiteKernelUtil::SubgraphOutputKernels(kernels); + ASSERT_EQ(output_kernels.size(), 1); + auto input_tensors = kernel::LiteKernelUtil::SubgraphInputTensors(kernels); + ASSERT_EQ(input_tensors.size(), 2); + auto output_tensors = kernel::LiteKernelUtil::SubgraphOutputTensors(kernels); + ASSERT_EQ(output_tensors.size(), 1); +} +} // namespace mindspore