From 6417449f3556802ee1beced5a8107da549a8f940 Mon Sep 17 00:00:00 2001 From: Yanzhan Yang Date: Wed, 22 May 2019 18:42:31 +0800 Subject: [PATCH] 1. add optimize parameter to load function. 2. add transpose operation if the input of fetch op is not [0, 2, 3, 1]. (#1651) --- .../FetchKernel.inc.metal | 1 - .../paddle-mobile/API/Runner.swift | 17 ++++++++++---- .../paddle-mobile/Src/Framework/Loader.swift | 18 +++++++-------- .../Src/Operators/Kernels/FetchKernel.swift | 23 ++++++++++++++++++- 4 files changed, 44 insertions(+), 15 deletions(-) diff --git a/metal/paddle-mobile-metallib/paddle-mobile-metallib/FetchKernel.inc.metal b/metal/paddle-mobile-metallib/paddle-mobile-metallib/FetchKernel.inc.metal index 114aa15664..0a3f9b8fdd 100644 --- a/metal/paddle-mobile-metallib/paddle-mobile-metallib/FetchKernel.inc.metal +++ b/metal/paddle-mobile-metallib/paddle-mobile-metallib/FetchKernel.inc.metal @@ -57,5 +57,4 @@ kernel void FUNC(fetch, 1or2, P)(texture2d_array inTexture [[te output[gid.y * input_width + gid.x] = float4(input); } - #endif diff --git a/metal/paddle-mobile/paddle-mobile/API/Runner.swift b/metal/paddle-mobile/paddle-mobile/API/Runner.swift index 60de2a2169..b4f3c686f9 100644 --- a/metal/paddle-mobile/paddle-mobile/API/Runner.swift +++ b/metal/paddle-mobile/paddle-mobile/API/Runner.swift @@ -72,10 +72,20 @@ import Foundation return success } + /// load 模型, 返回 true 可进行预测,公共方法,保证线程安全 + /// + /// - Returns: load 成功或失败 + @objc public func load(optimize: Bool) -> Bool { + Runner.loadLock.lock() + let success = unSafeLoad(optimize: optimize) + Runner.loadLock.unlock() + return success + } + /// load 模型, 返回 true 可进行预测,不保证线程安全 /// /// - Returns: load 成功或失败 - private func unSafeLoad() -> Bool { + private func unSafeLoad(optimize: Bool = true) -> Bool { guard let inDevice = device, let inQueue = queue else { print(" paddle mobile gpu load error, need MTLCommandQueue") return false @@ -89,15 +99,14 @@ import Foundation } do { - if let inParamPointer = net.paramPointer, let inModelPointer = net.modelPointer { guard net.paramSize > 0 && net.modelSize > 0 else { print(" load from memory param size or model size can't 0 ") return false } - program = try loader.load(device: inDevice, paramPointer: inParamPointer, paramSize: net.paramSize,modePointer:inModelPointer,modelSize:net.modelSize) + program = try loader.load(device: inDevice, paramPointer: inParamPointer, paramSize: net.paramSize, modePointer: inModelPointer, modelSize: net.modelSize, optimize: optimize) } else if let inModelPath = net.modelPath, let inParamPath = net.paramPath { - program = try loader.load(device: inDevice, modelPath: inModelPath, paraPath: inParamPath) + program = try loader.load(device: inDevice, modelPath: inModelPath, paraPath: inParamPath, optimize: optimize) } else { print(" model pointer or model file path need be specified") return false diff --git a/metal/paddle-mobile/paddle-mobile/Src/Framework/Loader.swift b/metal/paddle-mobile/paddle-mobile/Src/Framework/Loader.swift index 664f2dfff9..ed11667ef7 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Framework/Loader.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Framework/Loader.swift @@ -16,11 +16,11 @@ import Foundation //import SwiftProtobuf protocol Loaderable { - func load(device:MTLDevice, paramPointer: UnsafeMutableRawPointer, paramSize:Int, modePointer: UnsafeMutableRawPointer, modelSize: Int) throws -> Program - func load(device: MTLDevice, modelPath: String, paraPath: String) throws -> Program + func load(device: MTLDevice, paramPointer: 
UnsafeMutableRawPointer, paramSize: Int, modePointer: UnsafeMutableRawPointer, modelSize: Int, optimize: Bool) throws -> Program + func load(device: MTLDevice, modelPath: String, paraPath: String, optimize: Bool) throws -> Program } -public class Loader: Loaderable{ +public class Loader: Loaderable { class ParaLoader { let file: UnsafeMutablePointer let fileSize: Int @@ -186,7 +186,7 @@ public class Loader: Loaderable{ } } public init(){} - private func loadModelandParam(_ device:MTLDevice,_ modelData:Data, _ paraLoaderPointer:ParaLoaderWithPointer?, _ paraLoader:ParaLoader?) throws -> Program { + private func loadModelandParam(_ device: MTLDevice, _ modelData: Data, _ paraLoaderPointer: ParaLoaderWithPointer?, _ paraLoader: ParaLoader?, _ optimize: Bool = true) throws -> Program { do { /// swift protobuf serialized Data to instance class // let protoProgram = try PaddleMobile_Framework_Proto_ProgramDesc.init( @@ -196,7 +196,7 @@ public class Loader: Loaderable{ let protoProgram = try ProgramDesc.init(data: (modelData as NSData) as Data) let originProgramDesc = PMProgramDesc.init(protoProgram: protoProgram) - let programDesc = ProgramOptimize
<P>.init().optimize(originProgramDesc: originProgramDesc)
+            let programDesc = optimize ? ProgramOptimize<P>
.init().optimize(originProgramDesc: originProgramDesc) : originProgramDesc // let programDesc = PMProgramDesc.init(protoProgram: protoProgram) if GlobalConfig.shared.debug { @@ -281,20 +281,20 @@ public class Loader: Loaderable{ throw PaddleMobileError.loaderError(message: "protobuf decoder error") } } - public func load(device:MTLDevice, paramPointer: UnsafeMutableRawPointer, paramSize:Int, modePointer: UnsafeMutableRawPointer, modelSize: Int) throws -> Program { + public func load(device: MTLDevice, paramPointer: UnsafeMutableRawPointer, paramSize: Int, modePointer: UnsafeMutableRawPointer, modelSize: Int, optimize: Bool = true) throws -> Program { let modelData = Data.init(bytes:modePointer, count:modelSize) guard let paraLoader = try? ParaLoaderWithPointer.init(pPointer: paramPointer,pSize: paramSize) else { throw PaddleMobileError.loaderError(message: "load para error") } do { - let program = try loadModelandParam(device,modelData,paraLoader,nil) + let program = try loadModelandParam(device, modelData, paraLoader, nil, optimize) return program } catch let error { throw error } } - public func load(device: MTLDevice, modelPath: String, paraPath: String) throws -> Program { + public func load(device: MTLDevice, modelPath: String, paraPath: String, optimize: Bool = true) throws -> Program { guard let modelData = try? Data.init(contentsOf: URL.init(fileURLWithPath: modelPath)) else { throw PaddleMobileError.loaderError(message: "load " + modelPath + " failed !") } @@ -303,7 +303,7 @@ public class Loader: Loaderable{ } do { - let program = try loadModelandParam(device,modelData,nil,paraLoader) + let program = try loadModelandParam(device, modelData, nil, paraLoader, optimize) return program } catch let error { throw error diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/FetchKernel.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/FetchKernel.swift index 5cfd597de1..dde2353036 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/FetchKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/FetchKernel.swift @@ -15,6 +15,9 @@ import Foundation class FetchKernel: Kernel, Computable { + var expectedTranspose: [Int]? + var device: MTLDevice? + var initContext: InitContext? required init(device: MTLDevice, param: FetchParam
<P>
, initContext: InitContext) throws { param.output.initBuffer(device: device) @@ -25,6 +28,9 @@ class FetchKernel: Kernel, Computable { switch param.input.tensorDim.cout() { case 1, 2: super.init(device: device, inFunctionName: "fetch_1or2_half", initContext: initContext) + case 4: + expectedTranspose = [0, 2, 3, 1] + super.init(device: device, inFunctionName: "fetch_half", initContext: initContext) default: fatalError(" not support ") } @@ -38,6 +44,9 @@ class FetchKernel: Kernel, Computable { switch param.input.tensorDim.cout() { case 1, 2: super.init(device: device, inFunctionName: "fetch_1or2_float", initContext: initContext) + case 4: + expectedTranspose = [0, 2, 3, 1] + super.init(device: device, inFunctionName: "fetch_float", initContext: initContext) default: fatalError(" not support ") } @@ -47,13 +56,25 @@ class FetchKernel: Kernel, Computable { } else { fatalError(" not support ") } + self.device = device + self.initContext = initContext } func compute(commandBuffer: MTLCommandBuffer, param: FetchParam
<P>
) throws { + var input = param.input + if let expectedTranspose = expectedTranspose { + if param.input.transpose != expectedTranspose { + if let device = device, let initContext = initContext, let transposedInput = encodeTransposeInput(input: param.input, toTranspose: expectedTranspose, commandBuffer: commandBuffer, device: device, initContext: initContext) { + input = transposedInput + } else { + print("input transpose failed in slice kernel") + } + } + } guard let encoder = commandBuffer.makeComputeCommandEncoder() else { throw PaddleMobileError.predictError(message: " encode is nil") } - encoder.setTexture(param.input.metalTexture, index: 0) + encoder.setTexture(input.metalTexture, index: 0) encoder.setBuffer(param.output.resultBuffer!, offset: 0, index: 0) encoder.dispatch(computePipline: pipline, outTexture: param.input.metalTexture) encoder.endEncoding() -- GitLab
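
Usage sketch (illustrative, not part of the diff above): a minimal example of driving the new load(optimize:) overload from application code. The Runner(inNet:commandQueue:) initializer, the MobileNet Net subclass, the module name, and the device/queue setup are assumptions made for illustration; only load() and load(optimize:) come from this change.

    import Metal
    import paddle_mobile   // assumed module name

    // Hypothetical setup; MobileNet stands in for any existing Net subclass.
    guard let device = MTLCreateSystemDefaultDevice(),
          let queue = device.makeCommandQueue() else {
        fatalError("Metal is unavailable")
    }
    let net = MobileNet(device: device)
    let runner = Runner(inNet: net, commandQueue: queue)   // assumed initializer

    // optimize: false skips the ProgramOptimize<P> pass on the loaded program;
    // the existing load() keeps the previous behaviour (optimize defaults to true).
    if !runner.load(optimize: false) {
        print("paddle-mobile load failed")
    }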
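
For reference, a standalone sketch of what expectedTranspose = [0, 2, 3, 1] means for a 4-D fetch input: a dim vector in NCHW order is permuted to NHWC before fetch_float / fetch_half read the texture. The permute helper below is illustrative only and does not exist in the codebase.

    // Apply a transpose order to a dim vector, e.g. NCHW -> NHWC.
    func permute(_ dims: [Int], by order: [Int]) -> [Int] {
        return order.map { dims[$0] }
    }

    let nchw = [1, 3, 224, 224]                   // N, C, H, W
    let nhwc = permute(nchw, by: [0, 2, 3, 1])    // [1, 224, 224, 3] = N, H, W, C
    print(nhwc)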