提交 1b6e5fac 编写于 作者: Y yangjie159

fix subgraph inputs and add ut

上级 6c4b4f91
......@@ -42,11 +42,16 @@ int LiteKernel::DecOutTensorRefCount() {
std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputKernels(
const std::vector<kernel::LiteKernel *> &kernels) {
std::vector<kernel::LiteKernel *> input_kernels;
for (const auto kernel : kernels) {
for (auto input : kernel->in_kernels()) {
for (const auto &kernel : kernels) {
if (kernel->in_kernels().empty() && !kernel->in_tensors().empty()) {
input_kernels.emplace_back(kernel);
continue;
}
for (const auto &input : kernel->in_kernels()) {
auto iter = std::find(kernels.begin(), kernels.end(), input);
if (iter == kernels.end()) {
input_kernels.emplace_back(input);
auto item = std::find(input_kernels.begin(), input_kernels.end(), kernel);
if (iter == kernels.end() && item == input_kernels.end()) {
input_kernels.emplace_back(kernel);
}
}
}
......@@ -56,11 +61,16 @@ std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputKernels(
std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphOutputKernels(
const std::vector<kernel::LiteKernel *> &kernels) {
std::vector<kernel::LiteKernel *> output_kernels;
for (const auto kernel : kernels) {
for (const auto output : kernel->out_kernels()) {
for (const auto &kernel : kernels) {
if (kernel->out_kernels().empty() && !kernel->out_tensors().empty()) {
output_kernels.emplace_back(kernel);
continue;
}
for (const auto &output : kernel->out_kernels()) {
auto iter = std::find(kernels.begin(), kernels.end(), output);
if (iter == kernels.end()) {
output_kernels.emplace_back(output);
auto item = std::find(output_kernels.begin(), output_kernels.end(), kernel);
if (iter == kernels.end() && item == output_kernels.end()) {
output_kernels.emplace_back(kernel);
}
}
}
......
......@@ -289,6 +289,7 @@ set(TEST_SRC
${TEST_DIR}/main.cc
${TEST_DIR}/ut/src/runtime/kernel/arm/common/pack_tests.cc
${TEST_DIR}/ut/src/infer_test.cc
${TEST_DIR}/ut/src/utils_test.cc
)
if (SUPPORT_TRAIN)
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <memory>
#include "mindspore/lite/schema/inner/model_generated.h"
#include "mindspore/lite/include/model.h"
#include "common/common_test.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "mindspore/core/utils/log_adapter.h"
#include "mindspore/lite/src/lite_kernel.h"
namespace mindspore {
class UtilsTest : public mindspore::CommonTest {
public:
UtilsTest() {}
};
TEST_F(UtilsTest, TestSubgraph) {
auto kernel0 = std::make_shared<kernel::LiteKernel>();
auto kernel1 = std::make_shared<kernel::LiteKernel>();
auto kernel2 = std::make_shared<kernel::LiteKernel>();
auto tensor0 = std::make_shared<lite::tensor::Tensor>();
auto tensor1 = std::make_shared<lite::tensor::Tensor>();
auto tensor2 = std::make_shared<lite::tensor::Tensor>();
auto tensor3 = std::make_shared<lite::tensor::Tensor>();
auto tensor4 = std::make_shared<lite::tensor::Tensor>();
kernel0->AddOutKernel(kernel1.get());
kernel1->AddInKernel(kernel0.get());
kernel1->AddOutKernel(kernel2.get());
kernel2->AddInKernel(kernel1.get());
kernel0->set_in_tensors({tensor0.get(), tensor1.get()});
kernel0->set_out_tensors({tensor2.get()});
kernel1->set_in_tensors({tensor2.get()});
kernel1->set_out_tensors({tensor3.get()});
kernel2->set_in_tensors({tensor3.get()});
kernel2->set_out_tensors({tensor4.get()});
std::vector<kernel::LiteKernel *> kernels = {kernel0.get(), kernel1.get(), kernel2.get()};
auto input_kernels = kernel::LiteKernelUtil::SubgraphInputKernels(kernels);
ASSERT_EQ(input_kernels.size(), 1);
auto output_kernels = kernel::LiteKernelUtil::SubgraphOutputKernels(kernels);
ASSERT_EQ(output_kernels.size(), 1);
auto input_tensors = kernel::LiteKernelUtil::SubgraphInputTensors(kernels);
ASSERT_EQ(input_tensors.size(), 2);
auto output_tensors = kernel::LiteKernelUtil::SubgraphOutputTensors(kernels);
ASSERT_EQ(output_tensors.size(), 1);
}
} // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册