Commit 66ad725e authored by Yanzhan Yang, committed by GitHub

1. Add an optimize parameter to the load function. 2. Add a transpose operation if the input of the fetch op is not in [0, 2, 3, 1] layout. (#1651)
Parent 9e63d954
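
For orientation, a minimal usage sketch of the new flag (the module import and the pre-configured runner are assumptions for illustration; only load(optimize:) itself is introduced by this commit):

// import paddle_mobile   // framework module name assumed; adjust to your integration

// `runner` is an already-configured Runner; its construction is outside this diff.
func loadWithoutFusion(_ runner: Runner) -> Bool {
    // Passing false skips the ProgramOptimize fusion pass and loads the original,
    // un-fused program description. The pre-existing parameterless load() keeps the
    // optimizing behavior, since unSafeLoad defaults optimize to true.
    return runner.load(optimize: false)
}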
@@ -57,5 +57,4 @@ kernel void FUNC(fetch, 1or2, P)(texture2d_array<P, access::read> inTexture [[te
output[gid.y * input_width + gid.x] = float4(input);
}
#endif
@@ -72,10 +72,20 @@ import Foundation
return success
}
/// Load the model. Returns true if prediction can proceed. Public method, thread-safe.
///
/// - Returns: whether the load succeeded
@objc public func load(optimize: Bool) -> Bool {
Runner.loadLock.lock()
let success = unSafeLoad(optimize: optimize)
Runner.loadLock.unlock()
return success
}
/// Load the model. Returns true if prediction can proceed. Not thread-safe.
///
/// - Returns: whether the load succeeded
private func unSafeLoad() -> Bool {
private func unSafeLoad(optimize: Bool = true) -> Bool {
guard let inDevice = device, let inQueue = queue else {
print(" paddle mobile gpu load error, need MTLCommandQueue")
return false
@@ -89,15 +99,14 @@ import Foundation
}
do {
if let inParamPointer = net.paramPointer, let inModelPointer = net.modelPointer {
guard net.paramSize > 0 && net.modelSize > 0 else {
print(" load from memory param size or model size can't 0 ")
return false
}
program = try loader.load(device: inDevice, paramPointer: inParamPointer, paramSize: net.paramSize,modePointer:inModelPointer,modelSize:net.modelSize)
program = try loader.load(device: inDevice, paramPointer: inParamPointer, paramSize: net.paramSize, modePointer: inModelPointer, modelSize: net.modelSize, optimize: optimize)
} else if let inModelPath = net.modelPath, let inParamPath = net.paramPath {
program = try loader.load(device: inDevice, modelPath: inModelPath, paraPath: inParamPath)
program = try loader.load(device: inDevice, modelPath: inModelPath, paraPath: inParamPath, optimize: optimize)
} else {
print(" model pointer or model file path need be specified")
return false
......
@@ -16,11 +16,11 @@ import Foundation
//import SwiftProtobuf
protocol Loaderable {
func load(device:MTLDevice, paramPointer: UnsafeMutableRawPointer, paramSize:Int, modePointer: UnsafeMutableRawPointer, modelSize: Int) throws -> Program
func load(device: MTLDevice, modelPath: String, paraPath: String) throws -> Program
func load(device: MTLDevice, paramPointer: UnsafeMutableRawPointer, paramSize: Int, modePointer: UnsafeMutableRawPointer, modelSize: Int, optimize: Bool) throws -> Program
func load(device: MTLDevice, modelPath: String, paraPath: String, optimize: Bool) throws -> Program
}
public class Loader<P: PrecisionProtocol>: Loaderable{
public class Loader<P: PrecisionProtocol>: Loaderable {
class ParaLoader {
let file: UnsafeMutablePointer<FILE>
let fileSize: Int
@@ -186,7 +186,7 @@ public class Loader<P: PrecisionProtocol>: Loaderable{
}
}
public init(){}
private func loadModelandParam(_ device:MTLDevice,_ modelData:Data, _ paraLoaderPointer:ParaLoaderWithPointer?, _ paraLoader:ParaLoader?) throws -> Program {
private func loadModelandParam(_ device: MTLDevice, _ modelData: Data, _ paraLoaderPointer: ParaLoaderWithPointer?, _ paraLoader: ParaLoader?, _ optimize: Bool = true) throws -> Program {
do {
/// swift protobuf serialized Data to instance class
// let protoProgram = try PaddleMobile_Framework_Proto_ProgramDesc.init(
@@ -196,7 +196,7 @@ public class Loader<P: PrecisionProtocol>: Loaderable{
let protoProgram = try ProgramDesc.init(data: (modelData as NSData) as Data)
let originProgramDesc = PMProgramDesc.init(protoProgram: protoProgram)
let programDesc = ProgramOptimize<P>.init().optimize(originProgramDesc: originProgramDesc)
let programDesc = optimize ? ProgramOptimize<P>.init().optimize(originProgramDesc: originProgramDesc) : originProgramDesc
// let programDesc = PMProgramDesc.init(protoProgram: protoProgram)
if GlobalConfig.shared.debug {
@@ -281,20 +281,20 @@ public class Loader<P: PrecisionProtocol>: Loaderable{
throw PaddleMobileError.loaderError(message: "protobuf decoder error")
}
}
public func load(device:MTLDevice, paramPointer: UnsafeMutableRawPointer, paramSize:Int, modePointer: UnsafeMutableRawPointer, modelSize: Int) throws -> Program {
public func load(device: MTLDevice, paramPointer: UnsafeMutableRawPointer, paramSize: Int, modePointer: UnsafeMutableRawPointer, modelSize: Int, optimize: Bool = true) throws -> Program {
let modelData = Data.init(bytes:modePointer, count:modelSize)
guard let paraLoader = try? ParaLoaderWithPointer.init(pPointer: paramPointer,pSize: paramSize) else {
throw PaddleMobileError.loaderError(message: "load para error")
}
do {
let program = try loadModelandParam(device,modelData,paraLoader,nil)
let program = try loadModelandParam(device, modelData, paraLoader, nil, optimize)
return program
} catch let error {
throw error
}
}
public func load(device: MTLDevice, modelPath: String, paraPath: String) throws -> Program {
public func load(device: MTLDevice, modelPath: String, paraPath: String, optimize: Bool = true) throws -> Program {
guard let modelData = try? Data.init(contentsOf: URL.init(fileURLWithPath: modelPath)) else {
throw PaddleMobileError.loaderError(message: "load " + modelPath + " failed !")
}
@@ -303,7 +303,7 @@ public class Loader<P: PrecisionProtocol>: Loaderable{
}
do {
let program = try loadModelandParam(device,modelData,nil,paraLoader)
let program = try loadModelandParam(device, modelData, nil, paraLoader, optimize)
return program
} catch let error {
throw error
......
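
One Swift detail behind the Loaderable change: protocol requirements cannot declare default argument values, so optimize is a plain Bool in the protocol while the concrete Loader supplies the = true default. A hypothetical sketch of the resulting call shapes (illustrative only, not part of this commit):

// import paddle_mobile   // framework module name assumed
import Metal

func loadSketch<P: PrecisionProtocol>(_ loader: Loader<P>, device: MTLDevice,
                                      modelPath: String, paraPath: String) throws -> Program {
    // On the concrete Loader the flag can be omitted; it defaults to true (optimized).
    _ = try loader.load(device: device, modelPath: modelPath, paraPath: paraPath)
    // Through the Loaderable protocol the flag must be passed explicitly,
    // because Swift protocols cannot carry default parameter values.
    let anyLoader: Loaderable = loader
    return try anyLoader.load(device: device, modelPath: modelPath, paraPath: paraPath, optimize: false)
}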
@@ -15,6 +15,9 @@
import Foundation
class FetchKernel<P: PrecisionProtocol>: Kernel, Computable {
var expectedTranspose: [Int]?
var device: MTLDevice?
var initContext: InitContext?
required init(device: MTLDevice, param: FetchParam<P>, initContext: InitContext) throws {
param.output.initBuffer(device: device)
@@ -25,6 +28,9 @@ class FetchKernel<P: PrecisionProtocol>: Kernel, Computable {
switch param.input.tensorDim.cout() {
case 1, 2:
super.init(device: device, inFunctionName: "fetch_1or2_half", initContext: initContext)
case 4:
expectedTranspose = [0, 2, 3, 1]
super.init(device: device, inFunctionName: "fetch_half", initContext: initContext)
default:
fatalError(" not support ")
}
@@ -38,6 +44,9 @@ class FetchKernel<P: PrecisionProtocol>: Kernel, Computable {
switch param.input.tensorDim.cout() {
case 1, 2:
super.init(device: device, inFunctionName: "fetch_1or2_float", initContext: initContext)
case 4:
expectedTranspose = [0, 2, 3, 1]
super.init(device: device, inFunctionName: "fetch_float", initContext: initContext)
default:
fatalError(" not support ")
}
@@ -47,13 +56,25 @@ class FetchKernel<P: PrecisionProtocol>: Kernel, Computable {
} else {
fatalError(" not support ")
}
self.device = device
self.initContext = initContext
}
func compute(commandBuffer: MTLCommandBuffer, param: FetchParam<P>) throws {
var input = param.input
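// If a 4-D input arrives in a transpose other than the expected [0, 2, 3, 1], re-encode it into that layout before binding it to the fetch kernel.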
if let expectedTranspose = expectedTranspose {
if param.input.transpose != expectedTranspose {
if let device = device, let initContext = initContext, let transposedInput = encodeTransposeInput(input: param.input, toTranspose: expectedTranspose, commandBuffer: commandBuffer, device: device, initContext: initContext) {
input = transposedInput
} else {
print("input transpose failed in slice kernel")
}
}
}
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(input.metalTexture, index: 0)
encoder.setBuffer(param.output.resultBuffer!, offset: 0, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.input.metalTexture)
encoder.endEncoding()
......
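
For readers unfamiliar with the transpose vectors above: [0, 2, 3, 1] is the permutation that reorders NCHW dimensions into NHWC order, which is what the fetch kernel expects for 4-D inputs. A tiny CPU-side sketch of what the permutation means (illustrative only, not the Metal path and not part of this commit):

// The i-th output dimension is dims[transpose[i]].
func permute(_ dims: [Int], by transpose: [Int]) -> [Int] {
    return transpose.map { dims[$0] }
}

// Example: a 1x3x224x224 NCHW tensor viewed through [0, 2, 3, 1] becomes 1x224x224x3 (NHWC).
let nhwcDims = permute([1, 3, 224, 224], by: [0, 2, 3, 1])   // [1, 224, 224, 3]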