Unverified commit 6417449f authored by Yanzhan Yang, committed by GitHub

1. add optimize parameter to load function. 2. add transpose operation if the input of fetch op is not [0, 2, 3, 1]. (#1651)

Parent 9be69f95
@@ -57,5 +57,4 @@ kernel void FUNC(fetch, 1or2, P)(texture2d_array<P, access::read> inTexture [[te
     output[gid.y * input_width + gid.x] = float4(input);
 }
 #endif
@@ -72,10 +72,20 @@ import Foundation
         return success
     }
+    /// Load the model. Returns true when prediction can proceed. Public method; thread-safe.
+    ///
+    /// - Returns: whether the load succeeded
+    @objc public func load(optimize: Bool) -> Bool {
+        Runner.loadLock.lock()
+        let success = unSafeLoad(optimize: optimize)
+        Runner.loadLock.unlock()
+        return success
+    }
     /// Load the model. Returns true when prediction can proceed. Not thread-safe.
     ///
     /// - Returns: whether the load succeeded
-    private func unSafeLoad() -> Bool {
+    private func unSafeLoad(optimize: Bool = true) -> Bool {
         guard let inDevice = device, let inQueue = queue else {
             print(" paddle mobile gpu load error, need MTLCommandQueue")
             return false
@@ -89,15 +99,14 @@ import Foundation
         }
         do {
             if let inParamPointer = net.paramPointer, let inModelPointer = net.modelPointer {
                 guard net.paramSize > 0 && net.modelSize > 0 else {
                     print(" load from memory param size or model size can't 0 ")
                     return false
                 }
-                program = try loader.load(device: inDevice, paramPointer: inParamPointer, paramSize: net.paramSize,modePointer:inModelPointer,modelSize:net.modelSize)
+                program = try loader.load(device: inDevice, paramPointer: inParamPointer, paramSize: net.paramSize, modePointer: inModelPointer, modelSize: net.modelSize, optimize: optimize)
             } else if let inModelPath = net.modelPath, let inParamPath = net.paramPath {
-                program = try loader.load(device: inDevice, modelPath: inModelPath, paraPath: inParamPath)
+                program = try loader.load(device: inDevice, modelPath: inModelPath, paraPath: inParamPath, optimize: optimize)
             } else {
                 print(" model pointer or model file path need be specified")
                 return false
...
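For reference, a minimal usage sketch of the thread-safe entry point added above. The already-constructed `runner` value stands in for whatever paddle-mobile setup the app performs; only the `load(optimize:)` call itself comes from this diff, and the wrapping function name is hypothetical.

// Minimal sketch, assuming `runner` is an already-configured paddle-mobile Runner.
func loadModel(with runner: Runner) -> Bool {
    // optimize: false skips the ProgramOptimize fusion pass that this flag now guards;
    // the public wrapper takes Runner.loadLock before delegating to unSafeLoad(optimize:).
    return runner.load(optimize: false)
}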
@@ -16,11 +16,11 @@ import Foundation
 //import SwiftProtobuf
 protocol Loaderable {
-    func load(device:MTLDevice, paramPointer: UnsafeMutableRawPointer, paramSize:Int, modePointer: UnsafeMutableRawPointer, modelSize: Int) throws -> Program
-    func load(device: MTLDevice, modelPath: String, paraPath: String) throws -> Program
+    func load(device: MTLDevice, paramPointer: UnsafeMutableRawPointer, paramSize: Int, modePointer: UnsafeMutableRawPointer, modelSize: Int, optimize: Bool) throws -> Program
+    func load(device: MTLDevice, modelPath: String, paraPath: String, optimize: Bool) throws -> Program
 }
-public class Loader<P: PrecisionProtocol>: Loaderable{
+public class Loader<P: PrecisionProtocol>: Loaderable {
     class ParaLoader {
         let file: UnsafeMutablePointer<FILE>
         let fileSize: Int
@@ -186,7 +186,7 @@ public class Loader<P: PrecisionProtocol>: Loaderable{
         }
     }
     public init(){}
-    private func loadModelandParam(_ device:MTLDevice,_ modelData:Data, _ paraLoaderPointer:ParaLoaderWithPointer?, _ paraLoader:ParaLoader?) throws -> Program {
+    private func loadModelandParam(_ device: MTLDevice, _ modelData: Data, _ paraLoaderPointer: ParaLoaderWithPointer?, _ paraLoader: ParaLoader?, _ optimize: Bool = true) throws -> Program {
         do {
             /// swift protobuf serialized Data to instance class
             // let protoProgram = try PaddleMobile_Framework_Proto_ProgramDesc.init(
@@ -196,7 +196,7 @@ public class Loader<P: PrecisionProtocol>: Loaderable{
             let protoProgram = try ProgramDesc.init(data: (modelData as NSData) as Data)
             let originProgramDesc = PMProgramDesc.init(protoProgram: protoProgram)
-            let programDesc = ProgramOptimize<P>.init().optimize(originProgramDesc: originProgramDesc)
+            let programDesc = optimize ? ProgramOptimize<P>.init().optimize(originProgramDesc: originProgramDesc) : originProgramDesc
             // let programDesc = PMProgramDesc.init(protoProgram: protoProgram)
             if GlobalConfig.shared.debug {
@@ -281,20 +281,20 @@ public class Loader<P: PrecisionProtocol>: Loaderable{
             throw PaddleMobileError.loaderError(message: "protobuf decoder error")
         }
     }
-    public func load(device:MTLDevice, paramPointer: UnsafeMutableRawPointer, paramSize:Int, modePointer: UnsafeMutableRawPointer, modelSize: Int) throws -> Program {
+    public func load(device: MTLDevice, paramPointer: UnsafeMutableRawPointer, paramSize: Int, modePointer: UnsafeMutableRawPointer, modelSize: Int, optimize: Bool = true) throws -> Program {
         let modelData = Data.init(bytes:modePointer, count:modelSize)
         guard let paraLoader = try? ParaLoaderWithPointer.init(pPointer: paramPointer,pSize: paramSize) else {
             throw PaddleMobileError.loaderError(message: "load para error")
         }
         do {
-            let program = try loadModelandParam(device,modelData,paraLoader,nil)
+            let program = try loadModelandParam(device, modelData, paraLoader, nil, optimize)
             return program
         } catch let error {
             throw error
         }
     }
-    public func load(device: MTLDevice, modelPath: String, paraPath: String) throws -> Program {
+    public func load(device: MTLDevice, modelPath: String, paraPath: String, optimize: Bool = true) throws -> Program {
         guard let modelData = try? Data.init(contentsOf: URL.init(fileURLWithPath: modelPath)) else {
             throw PaddleMobileError.loaderError(message: "load " + modelPath + " failed !")
         }
@@ -303,7 +303,7 @@ public class Loader<P: PrecisionProtocol>: Loaderable{
         }
         do {
-            let program = try loadModelandParam(device,modelData,nil,paraLoader)
+            let program = try loadModelandParam(device, modelData, nil, paraLoader, optimize)
             return program
         } catch let error {
             throw error
...
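A side note on the API shape above: Swift protocols cannot declare default argument values, so the `Loaderable` requirement spells out `optimize: Bool` while the concrete `Loader` methods add `= true`, which keeps existing call sites on the concrete type compiling unchanged. The standalone sketch below uses hypothetical names (`Loading`, `DemoLoader`) purely to illustrate that language rule; it is not part of paddle-mobile.

protocol Loading {
    func load(path: String, optimize: Bool) throws -> String
}

struct DemoLoader: Loading {
    // The default value lives on the conforming type, not on the protocol requirement;
    // calls through the protocol existential must still pass optimize explicitly.
    func load(path: String, optimize: Bool = true) throws -> String {
        return optimize ? "optimized program from \(path)" : "raw program from \(path)"
    }
}

let loader = DemoLoader()
let defaulted = try? loader.load(path: "model")                  // optimize defaults to true
let explicit = try? loader.load(path: "model", optimize: false)  // explicit opt-out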
@@ -15,6 +15,9 @@
 import Foundation
 class FetchKernel<P: PrecisionProtocol>: Kernel, Computable {
+    var expectedTranspose: [Int]?
+    var device: MTLDevice?
+    var initContext: InitContext?
     required init(device: MTLDevice, param: FetchParam<P>, initContext: InitContext) throws {
         param.output.initBuffer(device: device)
@@ -25,6 +28,9 @@ class FetchKernel<P: PrecisionProtocol>: Kernel, Computable {
             switch param.input.tensorDim.cout() {
             case 1, 2:
                 super.init(device: device, inFunctionName: "fetch_1or2_half", initContext: initContext)
+            case 4:
+                expectedTranspose = [0, 2, 3, 1]
+                super.init(device: device, inFunctionName: "fetch_half", initContext: initContext)
             default:
                 fatalError(" not support ")
             }
@@ -38,6 +44,9 @@ class FetchKernel<P: PrecisionProtocol>: Kernel, Computable {
             switch param.input.tensorDim.cout() {
             case 1, 2:
                 super.init(device: device, inFunctionName: "fetch_1or2_float", initContext: initContext)
+            case 4:
+                expectedTranspose = [0, 2, 3, 1]
+                super.init(device: device, inFunctionName: "fetch_float", initContext: initContext)
             default:
                 fatalError(" not support ")
             }
@@ -47,13 +56,25 @@ class FetchKernel<P: PrecisionProtocol>: Kernel, Computable {
         } else {
             fatalError(" not support ")
         }
+        self.device = device
+        self.initContext = initContext
     }
     func compute(commandBuffer: MTLCommandBuffer, param: FetchParam<P>) throws {
+        var input = param.input
+        if let expectedTranspose = expectedTranspose {
+            if param.input.transpose != expectedTranspose {
+                if let device = device, let initContext = initContext, let transposedInput = encodeTransposeInput(input: param.input, toTranspose: expectedTranspose, commandBuffer: commandBuffer, device: device, initContext: initContext) {
+                    input = transposedInput
+                } else {
print("input transpose failed in slice kernel")
+                }
+            }
+        }
         guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
             throw PaddleMobileError.predictError(message: " encode is nil")
         }
-        encoder.setTexture(param.input.metalTexture, index: 0)
+        encoder.setTexture(input.metalTexture, index: 0)
         encoder.setBuffer(param.output.resultBuffer!, offset: 0, index: 0)
         encoder.dispatch(computePipline: pipline, outTexture: param.input.metalTexture)
         encoder.endEncoding()
...
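To make the new fetch path explicit: for 4-D inputs the kernel expects the texture layout described by the transpose [0, 2, 3, 1], and compute now re-encodes a transpose pass whenever the incoming texture reports a different order, falling back to the untransposed input (with a log) if that fails. The sketch below isolates just the layout comparison; `needsFetchTranspose` is an illustrative helper, while the actual re-encoding in the diff is done by the codebase's `encodeTransposeInput(...)`, which is not shown here.

// Illustrative helper (not part of the diff): decide whether a 4-D texture
// must be transposed before the fetch kernel copies it into the result buffer.
func needsFetchTranspose(currentTranspose: [Int], expectedTranspose: [Int] = [0, 2, 3, 1]) -> Bool {
    return currentTranspose != expectedTranspose
}

// A texture still in [0, 1, 2, 3] (NCHW-style) order needs the extra pass,
// while one already in [0, 2, 3, 1] is fetched directly.
print(needsFetchTranspose(currentTranspose: [0, 1, 2, 3]))  // true
print(needsFetchTranspose(currentTranspose: [0, 2, 3, 1]))  // false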