提交 e5ab9b35 编写于 作者: L liuruilong

make pipline optional

上级 7d5d941e
......@@ -56,10 +56,10 @@ import Foundation
/// 输入维度,按照 n h w c 方式传入
@objc public var inputDim: Dim = Dim.init(inDim: [])
/// 是否使用 MetalPerformanceShaders 进行运算
/// 是否使用 MetalPerformanceShaders 进行运算, 运算精度为 32 位时不支持开启 MPS
@objc public var useMPS: Bool = false
/// 模型精度 - 当使用模型精度为 Float 16 时 不要开启 useMPS, 暂不支持
/// 模型精度
@objc public var paramPrecision: Precision = .Float32
@objc public init(device: MTLDevice, inParamPointer: UnsafeMutableRawPointer, inParamSize:Int, inModelPointer: UnsafeMutableRawPointer, inModelSize: Int) {
......
......@@ -38,11 +38,21 @@ protocol KernelProtocol {
}
@objc open class Kernel: NSObject{
let pipline: MTLComputePipelineState
let functionName: String
public init(device: MTLDevice, inFunctionName: String, usePaddleMobileLib: Bool = false, initContext: InitContext) {
pipline = device.pipeLine(funcName: inFunctionName, metalLoadMode: initContext.metalLoadMode, metalLibPath: initContext.metalLibPath)
private var _pipline: MTLComputePipelineState? = nil
var pipline: MTLComputePipelineState {
get {
return _pipline ?! " pipeline can't be nil "
}
}
let functionName: String?
public init(device: MTLDevice, inFunctionName: String?, usePaddleMobileLib: Bool = false, initContext: InitContext) {
functionName = inFunctionName
if let funcName = inFunctionName {
_pipline = device.pipeLine(funcName: funcName, metalLoadMode: initContext.metalLoadMode, metalLibPath: initContext.metalLibPath)
}
}
}
......@@ -104,7 +114,7 @@ open class BufferToTextureKernel: Kernel {
@objc open class CusomKernel: Kernel {
public let outputTexture: MTLTexture
public init(device: MTLDevice, inFunctionName: String, outputDim: Shape, metalLoadModel: MetalLoadMode, metalLibPath: String?) {
public init(device: MTLDevice, inFunctionName: String?, outputDim: Shape, metalLoadModel: MetalLoadMode, metalLibPath: String?) {
let textureDesc = MTLTextureDescriptor.init()
textureDesc.textureType = .type2D
textureDesc.width = outputDim.width
......
......@@ -145,7 +145,7 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
convDic[key] = conv
imageDic[identifyingKey + "_input"] = MPSImage.init(texture: param.input.metalTexture, featureChannels: param.input.tensorDim[1])
imageDic[identifyingKey + "_output"] = MPSImage.init(texture: param.output.metalTexture, featureChannels: param.output.tensorDim[1])
super.init(device: device, inFunctionName: "place_holder", initContext: initContext)
super.init(device: device, inFunctionName: nil, initContext: initContext)
return
}
}
......
......@@ -20,9 +20,9 @@ public class ScaleKernel: CusomKernel {
public init(device: MTLDevice, shape: Shape, metalLoadMode: MetalLoadMode, metalLibPath: String?) {
lanczos = MPSImageLanczosScale(device: device)
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "scale", outputDim: shape, metalLoadModel: metalLoadMode, metalLibPath: metalLibPath)
super.init(device: device, inFunctionName: nil, outputDim: shape, metalLoadModel: metalLoadMode, metalLibPath: metalLibPath)
} else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "scale_half", outputDim: shape, metalLoadModel: metalLoadMode, metalLibPath: metalLibPath)
super.init(device: device, inFunctionName: nil, outputDim: shape, metalLoadModel: metalLoadMode, metalLibPath: metalLibPath)
} else {
fatalError(" unsupport ")
}
......
......@@ -23,15 +23,6 @@ struct Texture2DTo2DArrayParam {
}
class Texture2DTo2DArrayKernel<P: PrecisionProtocol>: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: FeedParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input.mtlTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.input.mtlTexture)
encoder.endEncoding()
}
required init(device: MTLDevice, param: FeedParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
......@@ -42,6 +33,15 @@ class Texture2DTo2DArrayKernel<P: PrecisionProtocol>: Kernel, Computable{
} else {
fatalError()
}
}
func compute(commandBuffer: MTLCommandBuffer, param: FeedParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input.mtlTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.input.mtlTexture)
encoder.endEncoding()
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册