diff --git a/metal/paddle-mobile/paddle-mobile/API/Runner.swift b/metal/paddle-mobile/paddle-mobile/API/Runner.swift index 13b3e7cec3fa61dcc1a57413ade4028525c417ff..60de2a2169ef5024de3802e2282945bcbaf5a248 100644 --- a/metal/paddle-mobile/paddle-mobile/API/Runner.swift +++ b/metal/paddle-mobile/paddle-mobile/API/Runner.swift @@ -131,8 +131,12 @@ import Foundation /// - completion: 结果回调, 当 success 为 true 时 result 不为 nil @objc public func predict(texture: MTLTexture, completion: @escaping ( _ success: Bool, _ result: [ResultHolder]?) -> Void) { do { - - try self.executor?.predict(input: texture, dim: self.net.inputDim, completionHandle: { [weak self] (success, res) in + guard let executor = self.executor else { + print("executor is empty") + completion(false, nil) + return + } + try executor.predict(input: texture, dim: self.net.inputDim, completionHandle: { [weak self] (success, res) in if success, let SSelf = self, let res = res { let result = SSelf.net.fetchResult(paddleMobileRes: res) if result.count > 0 { diff --git a/metal/paddle-mobile/paddle-mobile/Src/Framework/Executor.swift b/metal/paddle-mobile/paddle-mobile/Src/Framework/Executor.swift index 478d19600e74960b73f26ca267e446882f6c5c47..8ddcf5f555be0856f318659bacbe17d0a9fcd8f4 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Framework/Executor.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Framework/Executor.swift @@ -78,9 +78,11 @@ public class Executor: Executorable{ public func predict(input: MTLTexture, dim: Dim, completionHandle: @escaping ( _ success: Bool, _ result: [GPUResultHolder]?) -> Void, preProcessKernle: CusomKernel? = nil, except: Int = 0) throws { inflightSemaphore.wait() guard isValid else { + inflightSemaphore.signal() throw PaddleMobileError.predictError(message: "Executor is cleared and invalid") } guard let buffer = queue.makeCommandBuffer() else { + inflightSemaphore.signal() throw PaddleMobileError.predictError(message: "CommandBuffer is nil") } diff --git a/metal/paddle-mobile/paddle-mobile/Src/Framework/Loader.swift b/metal/paddle-mobile/paddle-mobile/Src/Framework/Loader.swift index afd75e82d19d431e72f25f2a379fa35328dd98a3..588ae5de88784a0b15fc8465876d453fbc1877f2 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Framework/Loader.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Framework/Loader.swift @@ -45,11 +45,11 @@ public class Loader: Loaderable{ } func pointerReader(type: T.Type) -> T { - let ptr = UnsafeMutablePointer.allocate(capacity: MemoryLayout.size) + let ptr = UnsafeMutablePointer.allocate(capacity: 1) fread(ptr, 1, MemoryLayout.size, file) nowIndex += MemoryLayout.size let pointee = ptr.pointee - ptr.deinitialize(count: MemoryLayout.size) + ptr.deinitialize(count: 1) ptr.deallocate() return pointee } @@ -65,10 +65,48 @@ public class Loader: Loaderable{ let _ = pointerReader(type: UInt32.self) - let tensorDescSize = pointerReader(type: Int32.self) + // 读取张量信息 + let tensorDescSize = Int(pointerReader(type: Int32.self)) - fseek(file, Int(tensorDescSize), SEEK_CUR) - nowIndex += Int(tensorDescSize) + if GlobalConfig.shared.debug { + let tensorDescCharArray = UnsafeMutablePointer.allocate(capacity: tensorDescSize) + for i in 0...size * tensorDescSize) + var tensorDescFromParams: VarType_TensorDesc? + do { + tensorDescFromParams = try VarType_TensorDesc.init(data: data) + } catch let error { + print("\(error)") + } + tensorDescCharArray.deinitialize(count: tensorDescSize) + tensorDescCharArray.deallocate() + repeat { + guard let tensorDescFromParams = tensorDescFromParams, let dimsArrayFromParams = tensorDescFromParams.dimsArray else { + print("tensorDescFromParams is nil") + break + } + if tensorDescFromParams.dimsArray_Count != dimsArrayFromParams.count { + print("dimsArray_Count not equal to tensorDescFromParams.dimsArray.count") + break + } + if tensorDescFromParams.dimsArray_Count != tensor.tensorDim.cout() { + print("dimsArray_Count not equal to tensor.tensorDim.cout()") + break + } + for i in 0...size * tensorDescSize, SEEK_CUR) + } + nowIndex += MemoryLayout.size * tensorDescSize /* 这里没有根据 Data Type 去判断, 而是从外部泛型直接指定了精度