提交 e5ab9b35 编写于 作者: L liuruilong

make pipline optional

上级 7d5d941e
...@@ -56,10 +56,10 @@ import Foundation ...@@ -56,10 +56,10 @@ import Foundation
/// 输入维度,按照 n h w c 方式传入 /// 输入维度,按照 n h w c 方式传入
@objc public var inputDim: Dim = Dim.init(inDim: []) @objc public var inputDim: Dim = Dim.init(inDim: [])
/// 是否使用 MetalPerformanceShaders 进行运算 /// 是否使用 MetalPerformanceShaders 进行运算, 运算精度为 32 位时不支持开启 MPS
@objc public var useMPS: Bool = false @objc public var useMPS: Bool = false
/// 模型精度 - 当使用模型精度为 Float 16 时 不要开启 useMPS, 暂不支持 /// 模型精度
@objc public var paramPrecision: Precision = .Float32 @objc public var paramPrecision: Precision = .Float32
@objc public init(device: MTLDevice, inParamPointer: UnsafeMutableRawPointer, inParamSize:Int, inModelPointer: UnsafeMutableRawPointer, inModelSize: Int) { @objc public init(device: MTLDevice, inParamPointer: UnsafeMutableRawPointer, inParamSize:Int, inModelPointer: UnsafeMutableRawPointer, inModelSize: Int) {
......
...@@ -38,11 +38,21 @@ protocol KernelProtocol { ...@@ -38,11 +38,21 @@ protocol KernelProtocol {
} }
@objc open class Kernel: NSObject{ @objc open class Kernel: NSObject{
let pipline: MTLComputePipelineState
let functionName: String private var _pipline: MTLComputePipelineState? = nil
public init(device: MTLDevice, inFunctionName: String, usePaddleMobileLib: Bool = false, initContext: InitContext) {
pipline = device.pipeLine(funcName: inFunctionName, metalLoadMode: initContext.metalLoadMode, metalLibPath: initContext.metalLibPath) var pipline: MTLComputePipelineState {
get {
return _pipline ?! " pipeline can't be nil "
}
}
let functionName: String?
public init(device: MTLDevice, inFunctionName: String?, usePaddleMobileLib: Bool = false, initContext: InitContext) {
functionName = inFunctionName functionName = inFunctionName
if let funcName = inFunctionName {
_pipline = device.pipeLine(funcName: funcName, metalLoadMode: initContext.metalLoadMode, metalLibPath: initContext.metalLibPath)
}
} }
} }
...@@ -104,7 +114,7 @@ open class BufferToTextureKernel: Kernel { ...@@ -104,7 +114,7 @@ open class BufferToTextureKernel: Kernel {
@objc open class CusomKernel: Kernel { @objc open class CusomKernel: Kernel {
public let outputTexture: MTLTexture public let outputTexture: MTLTexture
public init(device: MTLDevice, inFunctionName: String, outputDim: Shape, metalLoadModel: MetalLoadMode, metalLibPath: String?) { public init(device: MTLDevice, inFunctionName: String?, outputDim: Shape, metalLoadModel: MetalLoadMode, metalLibPath: String?) {
let textureDesc = MTLTextureDescriptor.init() let textureDesc = MTLTextureDescriptor.init()
textureDesc.textureType = .type2D textureDesc.textureType = .type2D
textureDesc.width = outputDim.width textureDesc.width = outputDim.width
......
...@@ -145,7 +145,7 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable { ...@@ -145,7 +145,7 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
convDic[key] = conv convDic[key] = conv
imageDic[identifyingKey + "_input"] = MPSImage.init(texture: param.input.metalTexture, featureChannels: param.input.tensorDim[1]) imageDic[identifyingKey + "_input"] = MPSImage.init(texture: param.input.metalTexture, featureChannels: param.input.tensorDim[1])
imageDic[identifyingKey + "_output"] = MPSImage.init(texture: param.output.metalTexture, featureChannels: param.output.tensorDim[1]) imageDic[identifyingKey + "_output"] = MPSImage.init(texture: param.output.metalTexture, featureChannels: param.output.tensorDim[1])
super.init(device: device, inFunctionName: "place_holder", initContext: initContext) super.init(device: device, inFunctionName: nil, initContext: initContext)
return return
} }
} }
......
...@@ -20,9 +20,9 @@ public class ScaleKernel: CusomKernel { ...@@ -20,9 +20,9 @@ public class ScaleKernel: CusomKernel {
public init(device: MTLDevice, shape: Shape, metalLoadMode: MetalLoadMode, metalLibPath: String?) { public init(device: MTLDevice, shape: Shape, metalLoadMode: MetalLoadMode, metalLibPath: String?) {
lanczos = MPSImageLanczosScale(device: device) lanczos = MPSImageLanczosScale(device: device)
if GlobalConfig.shared.computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "scale", outputDim: shape, metalLoadModel: metalLoadMode, metalLibPath: metalLibPath) super.init(device: device, inFunctionName: nil, outputDim: shape, metalLoadModel: metalLoadMode, metalLibPath: metalLibPath)
} else if GlobalConfig.shared.computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "scale_half", outputDim: shape, metalLoadModel: metalLoadMode, metalLibPath: metalLibPath) super.init(device: device, inFunctionName: nil, outputDim: shape, metalLoadModel: metalLoadMode, metalLibPath: metalLibPath)
} else { } else {
fatalError(" unsupport ") fatalError(" unsupport ")
} }
......
...@@ -23,15 +23,6 @@ struct Texture2DTo2DArrayParam { ...@@ -23,15 +23,6 @@ struct Texture2DTo2DArrayParam {
} }
class Texture2DTo2DArrayKernel<P: PrecisionProtocol>: Kernel, Computable{ class Texture2DTo2DArrayKernel<P: PrecisionProtocol>: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: FeedParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input.mtlTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.input.mtlTexture)
encoder.endEncoding()
}
required init(device: MTLDevice, param: FeedParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: FeedParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision) param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
...@@ -42,6 +33,15 @@ class Texture2DTo2DArrayKernel<P: PrecisionProtocol>: Kernel, Computable{ ...@@ -42,6 +33,15 @@ class Texture2DTo2DArrayKernel<P: PrecisionProtocol>: Kernel, Computable{
} else { } else {
fatalError() fatalError()
} }
}
func compute(commandBuffer: MTLCommandBuffer, param: FeedParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input.mtlTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.input.mtlTexture)
encoder.endEncoding()
} }
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册