Commit 1c473559 authored by Advait Jain, committed by TensorFlower Gardener

Do not use the deprecated_builtin_code when creating flatbuffers.

This change avoids an implicit conversion from the BuiltinOperator enum to int8_t
in calls to CreateOperatorCodeDirect. As a result, once the BuiltinOperator enum
grows larger than a byte, the TFLM code should not get tripped up.

PiperOrigin-RevId: 339910848
Change-Id: Ief978f2bb9ea818f8211ade3927baa35dc7a8512
Parent e77b0dc6
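For reference, a minimal sketch of the call-site change, using the argument order of the generated CreateOperatorCodeDirect helper as it appears in the diff below (deprecated_builtin_code, custom_code, version, builtin_code):

    // Before: the BuiltinOperator value is passed in the position of the
    // int8_t deprecated_builtin_code parameter, so it is implicitly narrowed
    // and would be truncated once the enum no longer fits in a byte.
    CreateOperatorCodeDirect(*builder, BuiltinOperator_CUSTOM, "mock_custom", 0);

    // After: leave the deprecated field at 0 and pass the enum through the
    // trailing builtin_code parameter, which is wide enough for the full enum.
    CreateOperatorCodeDirect(*builder, /*deprecated_builtin_code=*/0,
                             "mock_custom", /*version=*/0,
                             BuiltinOperator_CUSTOM);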
@@ -86,8 +86,7 @@ class ModelBuilder {
       : builder_(builder) {}
 
   // Registers an operator that will be used in the model.
-  Operator RegisterOp(BuiltinOperator op, const char* custom_code,
-                      int32_t version);
+  Operator RegisterOp(BuiltinOperator op, const char* custom_code);
 
   // Adds a tensor to the model.
   Tensor AddTensor(TensorType type, std::initializer_list<int32_t> shape) {
@@ -146,11 +145,10 @@ class ModelBuilder {
 };
 
 ModelBuilder::Operator ModelBuilder::RegisterOp(BuiltinOperator op,
-                                                const char* custom_code,
-                                                int32_t version) {
+                                                const char* custom_code) {
   TFLITE_DCHECK(next_operator_code_id_ <= kMaxOperatorCodes);
-  operator_codes_[next_operator_code_id_] =
-      tflite::CreateOperatorCodeDirect(*builder_, op, custom_code, version);
+  operator_codes_[next_operator_code_id_] = tflite::CreateOperatorCodeDirect(
+      *builder_, /*deprecated_builtin_code=*/0, custom_code, /*version=*/0, op);
   next_operator_code_id_++;
   return next_operator_code_id_ - 1;
 }
@@ -262,7 +260,7 @@ const Model* BuildSimpleStatefulModel() {
   ModelBuilder model_builder(fb_builder);
 
   const int op_id =
-      model_builder.RegisterOp(BuiltinOperator_CUSTOM, "simple_stateful_op", 0);
+      model_builder.RegisterOp(BuiltinOperator_CUSTOM, "simple_stateful_op");
   const int input_tensor = model_builder.AddTensor(TensorType_UINT8, {3});
   const int median_tensor = model_builder.AddTensor(TensorType_UINT8, {3});
   const int invoke_count_tensor =
@@ -303,8 +301,7 @@ const Model* BuildSimpleModelWithBranch() {
          v
   */
   const int op_id =
-      model_builder.RegisterOp(BuiltinOperator_CUSTOM, "mock_custom",
-                               /* version= */ 0);
+      model_builder.RegisterOp(BuiltinOperator_CUSTOM, "mock_custom");
   const int t0 = model_builder.AddTensor(TensorType_FLOAT32, {2, 2, 3});
   const int t1 = model_builder.AddTensor(TensorType_FLOAT32, {2, 2, 3});
   const int t2 = model_builder.AddTensor(TensorType_FLOAT32, {2, 2, 3});
@@ -326,8 +323,7 @@ const Model* BuildModelWithOfflinePlanning(int number_of_tensors,
   ModelBuilder model_builder(fb_builder);
 
   const int op_id =
-      model_builder.RegisterOp(BuiltinOperator_CUSTOM, "mock_custom",
-                               /* version= */ 0);
+      model_builder.RegisterOp(BuiltinOperator_CUSTOM, "mock_custom");
 
   for (int i = 0; i < number_of_tensors; ++i) {
     model_builder.AddTensor(TensorType_FLOAT32, {2, 2, 3});
@@ -408,8 +404,9 @@ const Model* BuildSimpleMockModel() {
       builder->CreateString("test_subgraph"))};
   constexpr size_t operator_codes_size = 1;
   const Offset<OperatorCode> operator_codes[operator_codes_size] = {
-      CreateOperatorCodeDirect(*builder, BuiltinOperator_CUSTOM, "mock_custom",
-                               0)};
+      CreateOperatorCodeDirect(*builder, /*deprecated_builtin_code=*/0,
+                               "mock_custom",
+                               /*version=*/0, BuiltinOperator_CUSTOM)};
   const Offset<Model> model_offset = CreateModel(
       *builder, 0, builder->CreateVector(operator_codes, operator_codes_size),
       builder->CreateVector(subgraphs, subgraphs_size),
@@ -557,8 +554,9 @@ const Model* BuildComplexMockModel() {
 
   constexpr size_t operator_codes_size = 1;
   const Offset<OperatorCode> operator_codes[operator_codes_size] = {
-      CreateOperatorCodeDirect(*builder, BuiltinOperator_CUSTOM, "mock_custom",
-                               0)};
+      CreateOperatorCodeDirect(*builder, /*deprecated_builtin_code=*/0,
+                               "mock_custom",
+                               /*version=*/0, BuiltinOperator_CUSTOM)};
 
   const Offset<Model> model_offset = CreateModel(
       *builder, 0, builder->CreateVector(operator_codes, operator_codes_size),