提交 439bf24e 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Adds initial interface for GPU delegation, using Metal, to TensorFlow Lite Swift library.

PiperOrigin-RevId: 262483401
上级 3d55e8a1
// Copyright 2019 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import TensorFlowLiteC
/// A delegate that the `Interpreter` uses to perform TensorFlow Lite model computations.
public protocol Delegate: class {
/// `TFL_Delegate` C pointer type.
typealias CDelegate = OpaquePointer
/// Delegate that performs model computations.
var cDelegate: CDelegate? { get }
}
...@@ -17,8 +17,7 @@ import TensorFlowLiteC ...@@ -17,8 +17,7 @@ import TensorFlowLiteC
/// A TensorFlow Lite interpreter that performs inference from a given model. /// A TensorFlow Lite interpreter that performs inference from a given model.
public final class Interpreter { public final class Interpreter {
/// `TFL_Interpreter` C pointer type represented as an `UnsafePointer<TFL_Interpreter>`.
/// The `TFL_Interpreter` C pointer type represented as an `UnsafePointer<TFL_Interpreter>`.
private typealias CInterpreter = OpaquePointer private typealias CInterpreter = OpaquePointer
/// Total number of input tensors associated with the model. /// Total number of input tensors associated with the model.
...@@ -31,15 +30,15 @@ public final class Interpreter { ...@@ -31,15 +30,15 @@ public final class Interpreter {
return Int(TFL_InterpreterGetOutputTensorCount(cInterpreter)) return Int(TFL_InterpreterGetOutputTensorCount(cInterpreter))
} }
/// The underlying `TFL_Interpreter` C pointer. /// Underlying `TFL_Interpreter` C pointer.
private var cInterpreter: CInterpreter? private var cInterpreter: CInterpreter?
/// Creates a new model interpreter instance. /// Creates a new model interpreter instance.
/// ///
/// - Parameters: /// - Parameters:
/// - modelPath: Local file path to a TensorFlow Lite model. /// - modelPath: Local file path to a TensorFlow Lite model.
/// - options: Custom configurations for the interpreter. The default is `nil` indicating that /// - options: Custom configurations for the interpreter. Default is `nil` indicating that the
/// the interpreter will determine the configuration options. /// interpreter will determine the configuration options.
/// - Throws: An error if the model could not be loaded or the interpreter could not be created. /// - Throws: An error if the model could not be loaded or the interpreter could not be created.
public init(modelPath: String, options: InterpreterOptions? = nil) throws { public init(modelPath: String, options: InterpreterOptions? = nil) throws {
guard let model = Model(filePath: modelPath) else { throw InterpreterError.failedToLoadModel } guard let model = Model(filePath: modelPath) else { throw InterpreterError.failedToLoadModel }
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import Foundation import Foundation
/// TensorFlow Lite interpreter errors. /// TensorFlow Lite interpreter errors.
public enum InterpreterError: Error { public enum InterpreterError: Error, Equatable, Hashable {
case invalidTensorIndex(index: Int, maxIndex: Int) case invalidTensorIndex(index: Int, maxIndex: Int)
case invalidTensorDataCount(provided: Int, required: Int) case invalidTensorDataCount(provided: Int, required: Int)
case invalidTensorDataType case invalidTensorDataType
...@@ -37,8 +37,8 @@ extension InterpreterError: LocalizedError { ...@@ -37,8 +37,8 @@ extension InterpreterError: LocalizedError {
switch self { switch self {
case .invalidTensorIndex(let index, let maxIndex): case .invalidTensorIndex(let index, let maxIndex):
return "Invalid tensor index \(index), max index is \(maxIndex)." return "Invalid tensor index \(index), max index is \(maxIndex)."
case .invalidTensorDataCount(let providedCount, let requiredCount): case .invalidTensorDataCount(let provided, let required):
return "Provided data count \(providedCount) must match the required count \(requiredCount)." return "Provided data count \(provided) must match the required count \(required)."
case .invalidTensorDataType: case .invalidTensorDataType:
return "Tensor data type is unsupported or could not be determined due to a model error." return "Tensor data type is unsupported or could not be determined due to a model error."
case .failedToLoadModel: case .failedToLoadModel:
...@@ -63,9 +63,5 @@ extension InterpreterError: LocalizedError { ...@@ -63,9 +63,5 @@ extension InterpreterError: LocalizedError {
extension InterpreterError: CustomStringConvertible { extension InterpreterError: CustomStringConvertible {
/// Textual representation of the TensorFlow Lite interpreter error. /// Textual representation of the TensorFlow Lite interpreter error.
public var description: String { public var description: String { return errorDescription ?? "Unknown error." }
return errorDescription ?? "Unknown error."
}
} }
extension InterpreterError: Equatable {}
...@@ -14,9 +14,8 @@ ...@@ -14,9 +14,8 @@
/// Custom configuration options for a TensorFlow Lite `Interpreter`. /// Custom configuration options for a TensorFlow Lite `Interpreter`.
public struct InterpreterOptions: Equatable { public struct InterpreterOptions: Equatable {
/// Maximum number of CPU threads that the interpreter should run on. Default is `nil` indicating
/// Maximum number of CPU threads that the interpreter should run on. Default is `nil` which /// that the `Interpreter` will decide the number of threads to use.
/// indicates that the `Interpreter` will decide the number of threads to use.
public var threadCount: Int? = nil public var threadCount: Int? = nil
/// Creates a new instance of interpreter options. /// Creates a new instance of interpreter options.
......
...@@ -16,11 +16,10 @@ import TensorFlowLiteC ...@@ -16,11 +16,10 @@ import TensorFlowLiteC
/// A TensorFlow Lite model used by the 'Interpreter` to perform inference. /// A TensorFlow Lite model used by the 'Interpreter` to perform inference.
final class Model { final class Model {
/// `TFL_Model` C pointer type represented as an `UnsafePointer<TFL_Model>`.
/// The `TFL_Model` C pointer type represented as an `UnsafePointer<TFL_Model>`.
typealias CModel = OpaquePointer typealias CModel = OpaquePointer
/// The underlying `TFL_Model` C pointer. /// Underlying `TFL_Model` C pointer.
let cModel: CModel? let cModel: CModel?
/// Creates a new model instance. /// Creates a new model instance.
......
...@@ -15,8 +15,7 @@ ...@@ -15,8 +15,7 @@
/// Parameters that determine the mapping of quantized values to real values. Quantized values can /// Parameters that determine the mapping of quantized values to real values. Quantized values can
/// be mapped to float values using the following conversion: /// be mapped to float values using the following conversion:
/// `realValue = scale * (quantizedValue - zeroPoint)`. /// `realValue = scale * (quantizedValue - zeroPoint)`.
public struct QuantizationParameters { public struct QuantizationParameters: Equatable, Hashable {
/// Difference between real values corresponding to consecutive quantized values differing by 1. /// Difference between real values corresponding to consecutive quantized values differing by 1.
/// For example, the range of quantized values for `UInt8` data type is [0, 255]. /// For example, the range of quantized values for `UInt8` data type is [0, 255].
public let scale: Float public let scale: Float
......
...@@ -16,8 +16,7 @@ import Foundation ...@@ -16,8 +16,7 @@ import Foundation
import TensorFlowLiteC import TensorFlowLiteC
/// An input or output tensor in a TensorFlow Lite graph. /// An input or output tensor in a TensorFlow Lite graph.
public struct Tensor { public struct Tensor: Equatable, Hashable {
/// Name of the tensor. /// Name of the tensor.
public let name: String public let name: String
...@@ -38,9 +37,10 @@ public struct Tensor { ...@@ -38,9 +37,10 @@ public struct Tensor {
/// - Parameters: /// - Parameters:
/// - name: Name of the tensor. /// - name: Name of the tensor.
/// - dataType: Data type of the tensor. /// - dataType: Data type of the tensor.
/// - shape: Shape of the tensor.
/// - data: Data in the input tensor. /// - data: Data in the input tensor.
/// - quantizationParameters Quantization parameters for the tensor if using a quantized model. /// - quantizationParameters Quantization parameters for the tensor if using a quantized model.
/// The default is `nil`. /// Default is `nil`.
init( init(
name: String, name: String,
dataType: TensorDataType, dataType: TensorDataType,
...@@ -57,7 +57,7 @@ public struct Tensor { ...@@ -57,7 +57,7 @@ public struct Tensor {
} }
/// Supported TensorFlow Lite tensor data types. /// Supported TensorFlow Lite tensor data types.
public enum TensorDataType: Equatable { public enum TensorDataType: Equatable, Hashable {
/// Boolean. /// Boolean.
case bool case bool
/// 8-bit unsigned integer. /// 8-bit unsigned integer.
...@@ -102,7 +102,7 @@ public enum TensorDataType: Equatable { ...@@ -102,7 +102,7 @@ public enum TensorDataType: Equatable {
} }
/// The shape of a TensorFlow Lite tensor. /// The shape of a TensorFlow Lite tensor.
public struct TensorShape { public struct TensorShape: Equatable, Hashable {
/// The number of dimensions of the tensor. /// The number of dimensions of the tensor.
public let rank: Int public let rank: Int
......
...@@ -33,11 +33,3 @@ class QuantizationParametersTests: XCTestCase { ...@@ -33,11 +33,3 @@ class QuantizationParametersTests: XCTestCase {
XCTAssertNotEqual(parameters2, parameters3) XCTAssertNotEqual(parameters2, parameters3)
} }
} }
// MARK: - Extensions
extension QuantizationParameters: Equatable {
public static func == (lhs: QuantizationParameters, rhs: QuantizationParameters) -> Bool {
return lhs.scale == rhs.scale && lhs.zeroPoint == rhs.zeroPoint
}
}
...@@ -39,6 +39,38 @@ class TensorTests: XCTestCase { ...@@ -39,6 +39,38 @@ class TensorTests: XCTestCase {
XCTAssertEqual(inputTensor.quantizationParameters, quantizationParameters) XCTAssertEqual(inputTensor.quantizationParameters, quantizationParameters)
} }
func testTensor_Equatable() {
let name = "Tensor"
let dataType: TensorDataType = .uInt8
let shape = TensorShape(Constant.dimensions)
guard let data = name.data(using: .utf8) else { XCTFail("Data should not be nil."); return }
let quantizationParameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
let tensor1 = Tensor(
name: name,
dataType: dataType,
shape: shape,
data: data,
quantizationParameters: quantizationParameters
)
var tensor2 = Tensor(
name: name,
dataType: dataType,
shape: shape,
data: data,
quantizationParameters: quantizationParameters
)
XCTAssertEqual(tensor1, tensor2)
tensor2 = Tensor(
name: "Tensor2",
dataType: dataType,
shape: shape,
data: data,
quantizationParameters: quantizationParameters
)
XCTAssertNotEqual(tensor1, tensor2)
}
// MARK: - TensorShape // MARK: - TensorShape
func testTensorShape_InitWithArray() { func testTensorShape_InitWithArray() {
...@@ -58,6 +90,15 @@ class TensorTests: XCTestCase { ...@@ -58,6 +90,15 @@ class TensorTests: XCTestCase {
XCTAssertEqual(shape.rank, Constant.dimensions.count) XCTAssertEqual(shape.rank, Constant.dimensions.count)
XCTAssertEqual(shape.dimensions, Constant.dimensions) XCTAssertEqual(shape.dimensions, Constant.dimensions)
} }
func testTensorShape_Equatable() {
let shape1 = TensorShape(2, 2, 3)
var shape2: TensorShape = [2, 2, 3]
XCTAssertEqual(shape1, shape2)
shape2 = [2, 2, 4]
XCTAssertNotEqual(shape1, shape2)
}
} }
// MARK: - Constants // MARK: - Constants
...@@ -66,18 +107,3 @@ private enum Constant { ...@@ -66,18 +107,3 @@ private enum Constant {
/// Array of 2 arrays of 2 arrays of 3 numbers: [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]. /// Array of 2 arrays of 2 arrays of 3 numbers: [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]].
static let dimensions = [2, 2, 3] static let dimensions = [2, 2, 3]
} }
// MARK: - Extensions
extension TensorShape: Equatable {
public static func == (lhs: TensorShape, rhs: TensorShape) -> Bool {
return lhs.rank == rhs.rank && lhs.dimensions == rhs.dimensions
}
}
extension Tensor: Equatable {
public static func == (lhs: Tensor, rhs: Tensor) -> Bool {
return lhs.name == rhs.name && lhs.dataType == rhs.dataType && lhs.shape == rhs.shape &&
lhs.data == rhs.data && lhs.quantizationParameters == rhs.quantizationParameters
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册