提交 ef827162 编写于 作者: Y Yanzhan Yang 提交者: GitHub

1.parse tensor desc info from params file. 2.ensure tensor desc is consistent...

1.parse tensor desc info from params file. 2.ensure tensor desc is consistent between params file and model file. (#1631)
上级 5945175f
......@@ -131,8 +131,12 @@ import Foundation
/// - completion: 结果回调, 当 success 为 true 时 result 不为 nil
@objc public func predict(texture: MTLTexture, completion: @escaping ( _ success: Bool, _ result: [ResultHolder]?) -> Void) {
do {
try self.executor?.predict(input: texture, dim: self.net.inputDim, completionHandle: { [weak self] (success, res) in
guard let executor = self.executor else {
print("executor is empty")
completion(false, nil)
return
}
try executor.predict(input: texture, dim: self.net.inputDim, completionHandle: { [weak self] (success, res) in
if success, let SSelf = self, let res = res {
let result = SSelf.net.fetchResult(paddleMobileRes: res)
if result.count > 0 {
......
......@@ -78,9 +78,11 @@ public class Executor<P: PrecisionProtocol>: Executorable{
public func predict(input: MTLTexture, dim: Dim, completionHandle: @escaping ( _ success: Bool, _ result: [GPUResultHolder]?) -> Void, preProcessKernle: CusomKernel? = nil, except: Int = 0) throws {
inflightSemaphore.wait()
guard isValid else {
inflightSemaphore.signal()
throw PaddleMobileError.predictError(message: "Executor is cleared and invalid")
}
guard let buffer = queue.makeCommandBuffer() else {
inflightSemaphore.signal()
throw PaddleMobileError.predictError(message: "CommandBuffer is nil")
}
......
......@@ -45,11 +45,11 @@ public class Loader<P: PrecisionProtocol>: Loaderable{
}
func pointerReader<T>(type: T.Type) -> T {
let ptr = UnsafeMutablePointer<T>.allocate(capacity: MemoryLayout<T>.size)
let ptr = UnsafeMutablePointer<T>.allocate(capacity: 1)
fread(ptr, 1, MemoryLayout<T>.size, file)
nowIndex += MemoryLayout<T>.size
let pointee = ptr.pointee
ptr.deinitialize(count: MemoryLayout<UInt32>.size)
ptr.deinitialize(count: 1)
ptr.deallocate()
return pointee
}
......@@ -65,10 +65,48 @@ public class Loader<P: PrecisionProtocol>: Loaderable{
let _ = pointerReader(type: UInt32.self)
let tensorDescSize = pointerReader(type: Int32.self)
// 读取张量信息
let tensorDescSize = Int(pointerReader(type: Int32.self))
fseek(file, Int(tensorDescSize), SEEK_CUR)
nowIndex += Int(tensorDescSize)
if GlobalConfig.shared.debug {
let tensorDescCharArray = UnsafeMutablePointer<CChar>.allocate(capacity: tensorDescSize)
for i in 0..<tensorDescSize {
let ch = pointerReader(type: CChar.self)
tensorDescCharArray[i] = ch
}
let data = Data(bytes: tensorDescCharArray, count: MemoryLayout<CChar>.size * tensorDescSize)
var tensorDescFromParams: VarType_TensorDesc?
do {
tensorDescFromParams = try VarType_TensorDesc.init(data: data)
} catch let error {
print("\(error)")
}
tensorDescCharArray.deinitialize(count: tensorDescSize)
tensorDescCharArray.deallocate()
repeat {
guard let tensorDescFromParams = tensorDescFromParams, let dimsArrayFromParams = tensorDescFromParams.dimsArray else {
print("tensorDescFromParams is nil")
break
}
if tensorDescFromParams.dimsArray_Count != dimsArrayFromParams.count {
print("dimsArray_Count not equal to tensorDescFromParams.dimsArray.count")
break
}
if tensorDescFromParams.dimsArray_Count != tensor.tensorDim.cout() {
print("dimsArray_Count not equal to tensor.tensorDim.cout()")
break
}
for i in 0..<dimsArrayFromParams.count {
if dimsArrayFromParams.value(at: i) != tensor.tensorDim[Int(i)] {
print("tensorDescFromParams \(String(describing: tensorDescFromParams.dimsArray)) not equal to tensor.tensorDim \(tensor.tensorDim)")
break
}
}
} while (false)
} else {
fseek(file, MemoryLayout<CChar>.size * tensorDescSize, SEEK_CUR)
}
nowIndex += MemoryLayout<CChar>.size * tensorDescSize
/*
这里没有根据 Data Type 去判断, 而是从外部泛型直接指定了精度
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册