提交 bf97e6a3 编写于 作者: X xuanyue

correct add/sub/div/mul/addn/cast/unique and add ut

上级 68fc7c2c
......@@ -384,15 +384,19 @@ table Eltwise {
}
table Add {
activationType : ActivationType;
}
table Sub {
activationType : ActivationType;
}
table Mul {
activationType : ActivationType;
}
table Div {
activationType : ActivationType;
}
table AddGrad {
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserAddN : public TestTfliteParser {
public:
TestTfliteParserAddN() {}
void SetUp() override {
meta_graph = LoadAndConvert("./addn.tflite");
}
};
TEST_F(TestTfliteParserAddN, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_AddN) << "wrong Op Type";
}
TEST_F(TestTfliteParserAddN, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsAddN(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsAddN()->N, 4);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserCast : public TestTfliteParser {
public:
TestTfliteParserCast() {}
void SetUp() override {
meta_graph = LoadAndConvert("./cast.tflite");
}
};
TEST_F(TestTfliteParserCast, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Cast) << "wrong Op Type";
}
TEST_F(TestTfliteParserCast, AttrValue) {
// float32 --> int32
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsCast(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->srcT, 43);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->dstT, 34);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserDepthToSpace : public TestTfliteParser {
public:
TestTfliteParserDepthToSpace() {}
void SetUp() override {
meta_graph = LoadAndConvert("./depth_to_space.tflite");
}
};
TEST_F(TestTfliteParserDepthToSpace, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_DepthToSpace) << "wrong Op Type";
}
TEST_F(TestTfliteParserDepthToSpace, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthToSpace(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->blockSize, 4);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->format, schema::Format_NHWC);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserEqual : public TestTfliteParser {
public:
TestTfliteParserEqual() {}
void SetUp() override {
meta_graph = LoadAndConvert("./equal.tflite");
}
};
TEST_F(TestTfliteParserEqual, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Equal) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserGreaterEqual : public TestTfliteParser {
public:
TestTfliteParserGreaterEqual() {}
void SetUp() override {
meta_graph = LoadAndConvert("./greater_equal.tflite");
}
};
TEST_F(TestTfliteParserGreaterEqual, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_GreaterEqual) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserGreater : public TestTfliteParser {
public:
TestTfliteParserGreater() {}
void SetUp() override {
meta_graph = LoadAndConvert("./greater.tflite");
}
};
TEST_F(TestTfliteParserGreater, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Greater) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserLessEqual : public TestTfliteParser {
public:
TestTfliteParserLessEqual() {}
void SetUp() override {
meta_graph = LoadAndConvert("./less_equal.tflite");
}
};
TEST_F(TestTfliteParserLessEqual, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_LessEqual) << "wrong Op Type";
}
} // namespace mindspore
......@@ -13,22 +13,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserExp : public TestTfliteParser {
class TestTfliteParserLess : public TestTfliteParser {
public:
TestTfliteParserExp() {}
void SetUp() override { meta_graph = LoadAndConvert("./exp.tflite", ""); }
TestTfliteParserLess() {}
void SetUp() override {
meta_graph = LoadAndConvert("./less.tflite");
}
};
TEST_F(TestTfliteParserExp, OpType) {
TEST_F(TestTfliteParserLess, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Exp) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Less) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserNotEqual : public TestTfliteParser {
public:
TestTfliteParserNotEqual() {}
void SetUp() override {
meta_graph = LoadAndConvert("./not_equal.tflite");
}
};
TEST_F(TestTfliteParserNotEqual, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_NotEqual) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserPrelu : public TestTfliteParser {
public:
TestTfliteParserPrelu() {}
void SetUp() override {
meta_graph = LoadAndConvert("./prelu.tflite");
}
};
TEST_F(TestTfliteParserPrelu, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Prelu) << "wrong Op Type";
}
TEST_F(TestTfliteParserPrelu, AttrValue) {
std::vector<float> slope(20, 0);
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPrelu(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsPrelu()->slope, slope);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserReverseSequence : public TestTfliteParser {
public:
TestTfliteParserReverseSequence() {}
void SetUp() override {
meta_graph = LoadAndConvert("./reverse_sequence.tflite");
}
};
TEST_F(TestTfliteParserReverseSequence, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ReverseSequence) << "wrong Op Type";
}
TEST_F(TestTfliteParserReverseSequence, AttrValue) {
std::vector<int> seq_length{7, 2, 3, 5};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverseSequence(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqLengths, seq_length);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserRound : public TestTfliteParser {
public:
TestTfliteParserRound() {}
void SetUp() override {
meta_graph = LoadAndConvert("./round.tflite");
}
};
TEST_F(TestTfliteParserRound, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Round) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserSpaceToBatchND : public TestTfliteParser {
public:
TestTfliteParserSpaceToBatchND() {}
void SetUp() override {
meta_graph = LoadAndConvert("./space_to_batch_nd.tflite");
}
};
TEST_F(TestTfliteParserSpaceToBatchND, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_SpaceToBatchND) << "wrong Op Type";
}
TEST_F(TestTfliteParserSpaceToBatchND, AttrValue) {
std::vector<int> blockshape{2, 2};
std::vector<int> padding{0, 0, 2, 0};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->blockShape, blockshape);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->paddings, padding);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserSpaceToDepth : public TestTfliteParser {
public:
TestTfliteParserSpaceToDepth() {}
void SetUp() override {
meta_graph = LoadAndConvert("./space_to_depth.tflite");
}
};
TEST_F(TestTfliteParserSpaceToDepth, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_SpaceToDepth) << "wrong Op Type";
}
TEST_F(TestTfliteParserSpaceToDepth, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->blockSize, 2);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->format, schema::Format_NHWC);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserSparseToDense : public TestTfliteParser {
public:
TestTfliteParserSparseToDense() {}
void SetUp() override {
meta_graph = LoadAndConvert("./sparse_to_dense.tflite");
}
};
TEST_F(TestTfliteParserSparseToDense, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_SparseToDense) << "wrong Op Type";
}
TEST_F(TestTfliteParserSparseToDense, AttrValue) {
std::vector<int> outputShape{5, 5};
std::vector<int> sparseValue{1};
std::vector<int> defaultValue{0};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSparseToDense(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->outputShape, outputShape);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->sparseValue, sparseValue);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->defaultValue, defaultValue);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->validateIndices, false);
}
} // namespace mindspore
......@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
......@@ -30,5 +31,4 @@ TEST_F(TestTfliteParserSquare, OpType) {
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Square) << "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserSquaredDifference : public TestTfliteParser {
public:
TestTfliteParserSquaredDifference() {}
void SetUp() override {
meta_graph = LoadAndConvert("./squared_difference.tflite");
}
};
TEST_F(TestTfliteParserSquaredDifference, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_SquaredDifference)
<< "wrong Op Type";
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserStridedSlice : public TestTfliteParser {
public:
TestTfliteParserStridedSlice() {}
void SetUp() override {
meta_graph = LoadAndConvert("./strided_slice.tflite");
}
};
TEST_F(TestTfliteParserStridedSlice, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_StridedSlice) << "wrong Op Type";
}
TEST_F(TestTfliteParserStridedSlice, AttrValue) {
std::vector<int> begin{1, -1, 0};
std::vector<int> end{2, -3, 3};
std::vector<int> stride{1, -1, 1};
std::vector<int> isscale{3, 2, 3};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStridedSlice(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->endMask, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->begin, begin);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->end, end);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->stride, stride);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->isScale, isscale);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserTile : public TestTfliteParser {
public:
TestTfliteParserTile() {}
void SetUp() override {
meta_graph = LoadAndConvert("./tile.tflite");
}
};
TEST_F(TestTfliteParserTile, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Tile) << "wrong Op Type";
}
TEST_F(TestTfliteParserTile, AttrValue) {
std::vector<int> multiply{2, 3, 4};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTile(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTile()->multiples, multiply);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserTopKV2 : public TestTfliteParser {
public:
TestTfliteParserTopKV2() {}
void SetUp() override {
meta_graph = LoadAndConvert("./topk_v2.tflite");
}
};
TEST_F(TestTfliteParserTopKV2, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_TopKV2) << "wrong Op Type";
}
TEST_F(TestTfliteParserTopKV2, AttrValue) {
// attr->sorted default is true
std::vector<int> k{3};
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->k, k);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->sorted, true);
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserUnique : public TestTfliteParser {
public:
TestTfliteParserUnique() {}
void SetUp() override {
meta_graph = LoadAndConvert("./unique.tflite");
}
};
TEST_F(TestTfliteParserUnique, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Unique) << "wrong Op Type";
}
TEST_F(TestTfliteParserUnique, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnique()->outType, 34); // int32
}
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h"
#include <iostream>
#include "common/common_test.h"
namespace mindspore {
class TestTfliteParserUnstack : public TestTfliteParser {
public:
TestTfliteParserUnstack() {}
void SetUp() override {
meta_graph = LoadAndConvert("./unstack.tflite");
}
};
TEST_F(TestTfliteParserUnstack, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Unstack) << "wrong Op Type";
}
TEST_F(TestTfliteParserUnstack, AttrValue) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnstack(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->num, 5);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->axis, 1);
}
} // namespace mindspore
......@@ -29,6 +29,11 @@ STATUS TfliteAddParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp
bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteAddParser";
std::unique_ptr<schema::AddT> attr(new schema::AddT());
const auto &tfliteAttr = tfliteOp->builtin_options.AsAddOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
auto weight_index = tfliteOp->inputs[1];
const auto &weight_tensor = tfliteTensors[weight_index];
......@@ -36,7 +41,7 @@ STATUS TfliteAddParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp
if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) {
return RET_ERROR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Add;
......
......@@ -30,7 +30,7 @@ STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_
bool quantized_model) {
MS_LOG(DEBUG) << "parse TfliteAddNParser";
std::unique_ptr<schema::AddNT> attr(new schema::AddNT());
attr->N = tflite_tensors.size();
attr->N = tflite_tensors.size() - 1;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_AddN;
......
......@@ -29,14 +29,9 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_
TensorCache *tensor_cache, bool quantized_model) {
MS_LOG(DEBUG) << "parse TfliteCastParser";
std::unique_ptr<schema::CastT> attr(new schema::CastT());
const auto &tflite_attr = tflite_op->builtin_options.AsCastOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op:" << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->srcT = tflite_attr->in_data_type;
attr->dstT = tflite_attr->out_data_type;
attr->srcT = dtype_map[tflite_tensors[tflite_op->inputs[0]]->type];
attr->dstT = dtype_map[tflite_tensors[tflite_op->outputs[0]]->type];
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
......
......@@ -19,6 +19,7 @@
#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
......@@ -34,6 +35,19 @@ class TfliteCastParser : public TfliteNodeParser {
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
private:
std::map<int, TypeId> dtype_map = {
{tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64},
{tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32},
{tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16},
{tflite::TensorType_INT64, TypeId::kNumberTypeInt64},
{tflite::TensorType_INT32, TypeId::kNumberTypeInt32},
{tflite::TensorType_INT16, TypeId::kNumberTypeInt16},
{tflite::TensorType_INT8, TypeId::kNumberTypeInt8},
{tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8},
{tflite::TensorType_BOOL, TypeId::kNumberTypeBool},
};
};
} // namespace lite
} // namespace mindspore
......
......@@ -42,7 +42,7 @@ STATUS TfliteDivParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Div;
......
......@@ -43,7 +43,7 @@ STATUS TfliteMulParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp
MS_LOG(ERROR) << "parse weight failed";
return RET_ERROR;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Mul;
......
......@@ -29,11 +29,6 @@ STATUS TfliteSparseToDenseParser::Parse(const std::unique_ptr<tflite::OperatorT>
TensorCache *tensor_cache, bool quantized_model) {
MS_LOG(DEBUG) << "parse TfliteSparseToDenseParser";
std::unique_ptr<schema::SparseToDenseT> attr(new schema::SparseToDenseT());
const auto &tflite_attr = tflite_op->builtin_options.AsSparseToDenseOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op:" << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->outputShape)) {
MS_LOG(ERROR) << "sparseToDense -> outputShape get failed";
......@@ -47,7 +42,7 @@ STATUS TfliteSparseToDenseParser::Parse(const std::unique_ptr<tflite::OperatorT>
MS_LOG(ERROR) << "sparseToDense -> defaultValue get failed";
return RET_ERROR;
}
attr->validateIndices = tflite_attr->validate_indices;
attr->validateIndices = false;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_SparseToDense;
......
......@@ -40,7 +40,7 @@ STATUS TfliteSubParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp
if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) {
MS_LOG(ERROR) << "parse weight failed";
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Sub;
......
......@@ -35,7 +35,7 @@ STATUS TfliteUniqueParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR;
}
attr->outType = GetTfliteDataType(tflite_attr->idx_out_type);
attr->outType = dtype_map[tflite_attr->idx_out_type];
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Unique;
......
......@@ -19,6 +19,7 @@
#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
......@@ -34,6 +35,19 @@ class TfliteUniqueParser : public TfliteNodeParser {
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op,
TensorCache *tensor_cache,
bool quantized_model) override;
private:
std::map<int, TypeId> dtype_map = {
{tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64},
{tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32},
{tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16},
{tflite::TensorType_INT64, TypeId::kNumberTypeInt64},
{tflite::TensorType_INT32, TypeId::kNumberTypeInt32},
{tflite::TensorType_INT16, TypeId::kNumberTypeInt16},
{tflite::TensorType_INT8, TypeId::kNumberTypeInt8},
{tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8},
{tflite::TensorType_BOOL, TypeId::kNumberTypeBool},
};
};
} // namespace lite
} // namespace mindspore
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册