提交 f9e8ec3f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4039 add ut for 21 tflite op parsers

Merge pull request !4039 from wangzhe/master
......@@ -191,8 +191,12 @@ if (SUPPORT_GPU)
endif()
### converter
if(BUILD_CONVERTER)
file(GLOB_RECURSE TEST_CASE_TFLITE_PARSERS_SRC
${TEST_DIR}/ut/tools/converter/parser/tflite/*.cc
)
set(TEST_LITE_SRC
${TEST_LITE_SRC}
${TEST_CASE_TFLITE_PARSERS_SRC}
${TOP_DIR}/mindspore/core/utils/flags.cc
${LITE_DIR}/tools/converter/optimizer.cc
${LITE_DIR}/src/common/anf_importer/anf_importer.cc
......
......@@ -7,6 +7,7 @@ mkdir -pv ${CUR_DIR}/do_test
cd ${CUR_DIR}/do_test
cp ${BUILD_DIR}/test/lite-test ./
cp -r ${CUR_DIR}/ut/src/runtime/kernel/arm/test_data/* ./
cp -r ${CUR_DIR}/ut/tools/converter/parser/tflite/test_data/* ./
## prepare data for dataset
TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/
cp -fr $TEST_DATA_DIR/testPK ./data
......@@ -14,6 +15,8 @@ cp -fr $TEST_DATA_DIR/testPK ./data
./lite-test --gtest_filter="*MindDataTestTensorDE*"
./lite-test --gtest_filter="*MindDataTestEager*"
./lite-test --gtest_filter="TestTfliteParser*"
./lite-test --gtest_filter="*TestHebing*"
./lite-test --gtest_filter=TestFcFp32*
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserAbs : public TestTfliteParser {
public:
TestTfliteParserAbs() {}
void SetUp() override { meta_graph = LoadAndConvert("./abs.tflite", ""); }
};
TEST_F(TestTfliteParserAbs, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Abs) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserBatchToSpaceNd : public TestTfliteParser {
public:
TestTfliteParserBatchToSpaceNd() {}
void SetUp() override { meta_graph = LoadAndConvert("./batch_to_space_nd.tflite"); }
};
TEST_F(TestTfliteParserBatchToSpaceNd, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_BatchToSpace) << "wrong Op Type";
}
TEST_F(TestTfliteParserBatchToSpaceNd, AttrValue) {
const std::vector<int> blockShape{2, 2};
const std::vector<int> crops{0, 0, 2, 0};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsBatchToSpace(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->blockShape, blockShape);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->crops, crops);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserCos : public TestTfliteParser {
public:
TestTfliteParserCos() {}
void SetUp() override { meta_graph = LoadAndConvert("./cos.tflite", ""); }
};
TEST_F(TestTfliteParserCos, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Cos) << "wrong Op Type";
}
} // namespace mindspore
......@@ -13,31 +13,22 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "mindspore/core/utils/log_adapter.h"
#include "common/common_test.h"
#include "tools/converter/converter_flags.h"
#include "schema/inner/model_generated.h"
#include "tools/converter/parser/tflite/tflite_converter.h"
#include "tools/converter/parser/tflite/tflite_exp_parser.h"
#include "src/kernel_registry.h"
#include "src/lite_kernel.h"
namespace mindspore {
class TestTfliteExpParser : public mindspore::Common {
class TestTfliteParserExp : public TestTfliteParser {
public:
TestTfliteExpParser() {}
TestTfliteParserExp() {}
void SetUp() override { meta_graph = LoadAndConvert("./exp.tflite", ""); }
};
TEST_F(TestTfliteExpParser, ExpParser) {
lite::converter::Flags flags;
flags.modelFile = "./test_data/Exp.tflite";
flags.fmk = lite::converter::FmkType_TFLITE;
lite::TfliteConverter converter;
schema::MetaGraphT *fb_graph = nullptr;
fb_graph = converter.Convert(&flags);
const auto &nodes = fb_graph->nodes;
nodes.back();
TEST_F(TestTfliteParserExp, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Exp) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserLog : public TestTfliteParser {
public:
TestTfliteParserLog() {}
void SetUp() override { meta_graph = LoadAndConvert("./log.tflite", ""); }
};
TEST_F(TestTfliteParserLog, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Log) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteLogicalParserAnd : public TestTfliteParser {
public:
TestTfliteLogicalParserAnd() {}
void SetUp() override { meta_graph = LoadAndConvert("./logical_and.tflite", ""); }
};
TEST_F(TestTfliteLogicalParserAnd, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_LogicalAnd) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserLogicalNot : public TestTfliteParser {
public:
TestTfliteParserLogicalNot() {}
void SetUp() override { meta_graph = LoadAndConvert("./logical_not.tflite", ""); }
};
TEST_F(TestTfliteParserLogicalNot, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_LogicalNot) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserLogicalOr : public TestTfliteParser {
public:
TestTfliteParserLogicalOr() {}
void SetUp() override { meta_graph = LoadAndConvert("./logical_or.tflite", ""); }
};
TEST_F(TestTfliteParserLogicalOr, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_LogicalOr) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserMaximum : public TestTfliteParser {
public:
TestTfliteParserMaximum() {}
void SetUp() override { meta_graph = LoadAndConvert("./maximum.tflite"); }
};
TEST_F(TestTfliteParserMaximum, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Maximum) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserMinimum : public TestTfliteParser {
public:
TestTfliteParserMinimum() {}
void SetUp() override { meta_graph = LoadAndConvert("./minimum.tflite"); }
};
TEST_F(TestTfliteParserMinimum, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Minimum) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserOneHot : public TestTfliteParser {
public:
TestTfliteParserOneHot() {}
void SetUp() override { meta_graph = LoadAndConvert("./one_hot.tflite"); }
};
TEST_F(TestTfliteParserOneHot, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_OneHot) << "wrong Op Type";
}
TEST_F(TestTfliteParserOneHot, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsOneHot(), nullptr);
// in OneHot parser axis = axis > 0 ? axis : axis + tensor_shape.size()
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsOneHot()->axis, 2);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <string>
#include "schema/inner/model_generated.h"
#include "tools/converter/parser/tflite/tflite_model_parser.h"
namespace mindspore {
schema::MetaGraphT *TestTfliteParser::LoadAndConvert(const string &model_path, const string &weight_path) {
schema::MetaGraphT *meta_graph = nullptr;
lite::TfliteModelParser parser;
meta_graph = parser.Parse(model_path, weight_path);
return meta_graph;
}
void TestTfliteParser::TearDown() { free(meta_graph); }
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TEST_UT_TOOLS_CONVERTER_PARSER_TFLITE_TFLITE_PARSERS_TEST_H_
#define MINDSPORE_LITE_TEST_UT_TOOLS_CONVERTER_PARSER_TFLITE_TFLITE_PARSERS_TEST_H_
#include <string>
#include "common/common_test.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
class TestTfliteParser : public Common {
public:
TestTfliteParser() {}
void TearDown() override;
schema::MetaGraphT *LoadAndConvert(const std::string &model_path, const std::string &weight_path = "");
schema::MetaGraphT *meta_graph;
};
} // namespace mindspore
#endif // MINDSPORE_LITE_TEST_UT_TOOLS_CONVERTER_PARSER_TFLITE_TFLITE_PARSERS_TEST_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserReduceMax : public TestTfliteParser {
public:
TestTfliteParserReduceMax() {}
void SetUp() override { meta_graph = LoadAndConvert("./reduce_max.tflite"); }
};
TEST_F(TestTfliteParserReduceMax, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reduce) << "wrong Op Type";
}
TEST_F(TestTfliteParserReduceMax, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->mode, schema::ReduceMode_ReduceMax)
<< "wrong reduce mode";
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->keepDims, false);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->axes.size(), 1);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->axes[0], 2);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserReduceMin : public TestTfliteParser {
public:
TestTfliteParserReduceMin() {}
void SetUp() override { meta_graph = LoadAndConvert("./reduce_min.tflite"); }
};
TEST_F(TestTfliteParserReduceMin, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reduce) << "wrong Op Type";
}
TEST_F(TestTfliteParserReduceMin, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->mode, schema::ReduceMode_ReduceMin)
<< "wrong reduce mode";
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->keepDims, false);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->axes.size(), 1);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->axes[0], 2);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserReduceProd : public TestTfliteParser {
public:
TestTfliteParserReduceProd() {}
void SetUp() override { meta_graph = LoadAndConvert("./reduce_prod.tflite"); }
};
TEST_F(TestTfliteParserReduceProd, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reduce) << "wrong Op Type";
}
TEST_F(TestTfliteParserReduceProd, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->mode, schema::ReduceMode_ReduceProd)
<< "wrong reduce mode";
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->keepDims, false);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->axes.size(), 1);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->axes[0], 2);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserRsqrt : public TestTfliteParser {
public:
TestTfliteParserRsqrt() {}
void SetUp() override { meta_graph = LoadAndConvert("./rsqrt.tflite", ""); }
};
TEST_F(TestTfliteParserRsqrt, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Rsqrt) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserSin : public TestTfliteParser {
public:
TestTfliteParserSin() {}
void SetUp() override { meta_graph = LoadAndConvert("./sin.tflite", ""); }
};
TEST_F(TestTfliteParserSin, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sin) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserSplit : public TestTfliteParser {
public:
TestTfliteParserSplit() {}
void SetUp() override { meta_graph = LoadAndConvert("./split.tflite"); }
};
TEST_F(TestTfliteParserSplit, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Split) << "wrong Op Type";
}
TEST_F(TestTfliteParserSplit, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr);
const std::vector<int> sizeSplits{2, 2};
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 2);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserSplitV : public TestTfliteParser {
public:
TestTfliteParserSplitV() {}
void SetUp() override { meta_graph = LoadAndConvert("./split_v.tflite"); }
};
TEST_F(TestTfliteParserSplitV, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Split) << "wrong Op Type";
}
TEST_F(TestTfliteParserSplitV, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr);
const std::vector<int> sizeSplits{1, 3};
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserSqrt : public TestTfliteParser {
public:
TestTfliteParserSqrt() {}
void SetUp() override { meta_graph = LoadAndConvert("./sqrt.tflite", ""); }
};
TEST_F(TestTfliteParserSqrt, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sqrt) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserSquare : public TestTfliteParser {
public:
TestTfliteParserSquare() {}
void SetUp() override { meta_graph = LoadAndConvert("./square.tflite", ""); }
};
TEST_F(TestTfliteParserSquare, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Square) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserSum : public TestTfliteParser {
public:
TestTfliteParserSum() {}
void SetUp() override { meta_graph = LoadAndConvert("./sum.tflite"); }
};
TEST_F(TestTfliteParserSum, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reduce) << "wrong Op Type";
}
TEST_F(TestTfliteParserSum, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->mode, schema::ReduceMode_ReduceSum)
<< "wrong reduce mode";
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->keepDims, false);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->axes.size(), 1);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReduce()->axes[0], 2);
}
} // namespace mindspore
......@@ -24,12 +24,13 @@ STATUS TfliteBatchToSpaceNDParser::Parse(const std::unique_ptr<tflite::OperatorT
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantizedModel) {
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(INFO) << "parse TfliteBatchToSpaceNDParser";
std::unique_ptr<schema::BatchToSpaceT> attr(new schema::BatchToSpaceT());
// in tflite
// blockShape should be a 1D tensor with dimension [spatial_dims_num]
// crops should be a 2D tensor with dimension [spatial_dims_num, 2]
if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) {
MS_LOG(ERROR) << "BatchToSpaceNd get blockShape attr failed";
return RET_ERROR;
......
......@@ -32,8 +32,12 @@ STATUS TfliteOneHotParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
MS_LOG(ERROR) << "get op: " << op->name << " attr failed";
return RET_NULL_PTR;
}
attr->axis = tflite_attr->axis;
auto axis = tflite_attr->axis;
const auto tensor_shape = tfliteTensors[tfliteOp->inputs[0]].get()->shape;
if (axis < 0) {
axis += tensor_shape.size();
}
attr->axis = axis;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
......
......@@ -39,6 +39,8 @@ STATUS TfliteScatterNdParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
MS_LOG(DEBUG) << i;
}
*/
// in tflite, kIndices = 0, kUpdates = 1, kShape = 2
// in mslite, kScatterShapeIndex = 0, kScatterIndicesIndex = 1, kScatterUpdateIndex = 2;
std::swap(op->inputIndex[0], op->inputIndex[2]);
std::swap(op->inputIndex[1], op->inputIndex[2]);
/*
......
......@@ -54,7 +54,7 @@ STATUS TfliteSplitParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite
}
attr->numberSplit = num_splits;
for (int i = 0; i <= num_splits; i++) {
for (int i = 0; i < num_splits; i++) {
attr->sizeSplits.push_back(tensor_shape[axis] / num_splits);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册