提交 8d11d7c9 编写于 作者: L liuruilong

add new interface

上级 bab922cc
......@@ -14,10 +14,10 @@
FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8520E11C550081E9F8 /* Main.storyboard */; };
FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8820E11C560081E9F8 /* Assets.xcassets */; };
FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8A20E11C560081E9F8 /* LaunchScreen.storyboard */; };
FC27991021341CE5000B6BAD /* Net.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC27990F21341CE5000B6BAD /* Net.swift */; };
FC27991321343A3A000B6BAD /* CPUCompute.mm in Sources */ = {isa = PBXBuildFile; fileRef = FC27991221343A3A000B6BAD /* CPUCompute.mm */; };
FC3C800F2133F46600D1295E /* MobileNetSSD.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3C800E2133F46600D1295E /* MobileNetSSD.swift */; };
FC3C80112133F4AB00D1295E /* MobileNet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3C80102133F4AB00D1295E /* MobileNet.swift */; };
FC4FD95121402B610073E130 /* PaddleMobile.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC4FD95021402B610073E130 /* PaddleMobile.swift */; };
FC8CFEE2213524EA0094D569 /* Genet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC8CFEE1213524EA0094D569 /* Genet.swift */; };
FC8CFEE62135452C0094D569 /* genet_params in Resources */ = {isa = PBXBuildFile; fileRef = FC8CFEE42135452B0094D569 /* genet_params */; };
FC8CFEE72135452C0094D569 /* genet_model in Resources */ = {isa = PBXBuildFile; fileRef = FC8CFEE52135452B0094D569 /* genet_model */; };
......@@ -61,12 +61,12 @@
FC039B8820E11C560081E9F8 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
FC039B8B20E11C560081E9F8 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/LaunchScreen.storyboard; sourceTree = "<group>"; };
FC039B8D20E11C560081E9F8 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
FC27990F21341CE5000B6BAD /* Net.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Net.swift; sourceTree = "<group>"; };
FC27991121343A39000B6BAD /* paddle-mobile-demo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "paddle-mobile-demo-Bridging-Header.h"; sourceTree = "<group>"; };
FC27991221343A3A000B6BAD /* CPUCompute.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = CPUCompute.mm; sourceTree = "<group>"; };
FC27991421343A46000B6BAD /* CPUCompute.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = CPUCompute.h; sourceTree = "<group>"; };
FC3C800E2133F46600D1295E /* MobileNetSSD.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MobileNetSSD.swift; sourceTree = "<group>"; };
FC3C80102133F4AB00D1295E /* MobileNet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MobileNet.swift; sourceTree = "<group>"; };
FC4FD95021402B610073E130 /* PaddleMobile.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; name = PaddleMobile.swift; path = ../../../../../../Desktop/PaddleMobile.swift; sourceTree = "<group>"; };
FC8CFEE1213524EA0094D569 /* Genet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Genet.swift; sourceTree = "<group>"; };
FC8CFEE42135452B0094D569 /* genet_params */ = {isa = PBXFileReference; lastKnownFileType = file; path = genet_params; sourceTree = "<group>"; };
FC8CFEE52135452B0094D569 /* genet_model */ = {isa = PBXFileReference; lastKnownFileType = file; path = genet_model; sourceTree = "<group>"; };
......@@ -176,11 +176,11 @@
FC8CFED2213519540094D569 /* Net */ = {
isa = PBXGroup;
children = (
FC4FD95021402B610073E130 /* PaddleMobile.swift */,
FC013927210204A3008100E3 /* PreProcessKernel.metal */,
FCBCCC542122EF5400D94F7E /* MetalHelper.swift */,
FC3C800E2133F46600D1295E /* MobileNetSSD.swift */,
FC3C80102133F4AB00D1295E /* MobileNet.swift */,
FC27990F21341CE5000B6BAD /* Net.swift */,
FC27991221343A3A000B6BAD /* CPUCompute.mm */,
FC27991421343A46000B6BAD /* CPUCompute.h */,
FC8CFEE1213524EA0094D569 /* Genet.swift */,
......@@ -342,11 +342,11 @@
files = (
FC039B8420E11C550081E9F8 /* ViewController.swift in Sources */,
FC013928210204A3008100E3 /* PreProcessKernel.metal in Sources */,
FC27991021341CE5000B6BAD /* Net.swift in Sources */,
FC8CFEE2213524EA0094D569 /* Genet.swift in Sources */,
FCBCCC552122EF5500D94F7E /* MetalHelper.swift in Sources */,
FC27991321343A3A000B6BAD /* CPUCompute.mm in Sources */,
FC3C80112133F4AB00D1295E /* MobileNet.swift in Sources */,
FC4FD95121402B610073E130 /* PaddleMobile.swift in Sources */,
FC3C800F2133F46600D1295E /* MobileNetSSD.swift in Sources */,
FC039B8220E11C550081E9F8 /* AppDelegate.swift in Sources */,
);
......@@ -494,19 +494,19 @@
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_IDENTITY = "iPhone Distribution";
CODE_SIGN_STYLE = Manual;
DEVELOPMENT_TEAM = 6T9LLJKSM4;
CODE_SIGN_IDENTITY = "iPhone Developer";
CODE_SIGN_STYLE = Automatic;
DEVELOPMENT_TEAM = A798K58VVL;
INFOPLIST_FILE = "paddle-mobile-demo/Info.plist";
IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
);
PRODUCT_BUNDLE_IDENTIFIER = com.baidu.mms.qa;
PRODUCT_BUNDLE_IDENTIFIER = "com.baidu.paddle-mobile";
PRODUCT_NAME = "$(TARGET_NAME)";
PROVISIONING_PROFILE = "ba9c4b24-7bd0-49c5-93cd-e3051e775d6c";
PROVISIONING_PROFILE_SPECIFIER = Distribution_MMS;
PROVISIONING_PROFILE = "";
PROVISIONING_PROFILE_SPECIFIER = "";
SWIFT_OBJC_BRIDGING_HEADER = "paddle-mobile-demo/paddle-mobile-demo-Bridging-Header.h";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 4.0;
......@@ -520,19 +520,19 @@
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_IDENTITY = "iPhone Distribution";
CODE_SIGN_STYLE = Manual;
DEVELOPMENT_TEAM = 6T9LLJKSM4;
CODE_SIGN_IDENTITY = "iPhone Developer";
CODE_SIGN_STYLE = Automatic;
DEVELOPMENT_TEAM = A798K58VVL;
INFOPLIST_FILE = "paddle-mobile-demo/Info.plist";
IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
);
PRODUCT_BUNDLE_IDENTIFIER = com.baidu.mms.qa;
PRODUCT_BUNDLE_IDENTIFIER = "com.baidu.paddle-mobile";
PRODUCT_NAME = "$(TARGET_NAME)";
PROVISIONING_PROFILE = "ba9c4b24-7bd0-49c5-93cd-e3051e775d6c";
PROVISIONING_PROFILE_SPECIFIER = Distribution_MMS;
PROVISIONING_PROFILE = "";
PROVISIONING_PROFILE_SPECIFIER = "";
SWIFT_OBJC_BRIDGING_HEADER = "paddle-mobile-demo/paddle-mobile-demo-Bridging-Header.h";
SWIFT_VERSION = 4.0;
TARGETED_DEVICE_FAMILY = "1,2";
......
......@@ -17,10 +17,6 @@ import paddle_mobile
class Genet: Net {
var program: Program?
var executor: Executor<Float32>?
let except: Int = 0
class GenetPreProccess: CusomKernel {
......
......@@ -28,25 +28,6 @@ class MetalHelper {
textureLoader = MTKTextureLoader.init(device: device)
}
static func scaleTexture(queue: MTLCommandQueue, input: MTLTexture, size:(width: Int, height: Int), complete: @escaping (MTLTexture) -> Void) {
guard let buffer = queue.makeCommandBuffer() else {
fatalError()
}
let scaleKernel = ScaleKernel.init(device: MetalHelper.shared.device, shape: CusomKernel.Shape.init(inWidth: size.width, inHeight: size.height, inChannel: 3))
do {
try scaleKernel.compute(inputTexuture: input, commandBuffer: buffer)
} catch let error {
print(error)
fatalError()
}
buffer.addCompletedHandler { (buffer) in
complete(scaleKernel.outputTexture)
}
buffer.commit()
}
}
......@@ -15,12 +15,10 @@
import Foundation
import paddle_mobile
class MobileNet: Net{
var program: Program?
var executor: Executor<Float32>?
let except: Int = 0
class MobilenetPreProccess: CusomKernel {
......
......@@ -14,13 +14,8 @@
import Foundation
import paddle_mobile
//import
class MobileNet_ssd_hand: Net{
var program: Program?
var executor: Executor<Float32>?
let except: Int = 2
class MobilenetssdPreProccess: CusomKernel {
init(device: MTLDevice) {
......@@ -83,36 +78,11 @@ class MobileNet_ssd_hand: Net{
let modelDir: String
// let paramPointer: UnsafeMutableRawPointer
//
// let paramSize: Int
//
// let modelPointer: UnsafeMutableRawPointer
//
// let modelSize: Int
//
// /**
// * inParamPointer: 参数文件内存地址
// * inParamSize: 参数文件大小(字节数)
// * inModelPointer: 模型文件内存地址
// * inModelSize: 模型文件大小(字节数)
// */
// init(inParamPointer: UnsafeMutableRawPointer, inParamSize: Int, inModelPointer: UnsafeMutableRawPointer, inModelSize: Int) {
// paramPointer = inParamPointer
// paramSize = inParamSize
// modelPointer = inModelPointer
// modelSize = inModelSize
//// fatalError()
// }
init() {
modelPath = Bundle.main.path(forResource: "ssd_hand_model", ofType: nil) ?! "model null"
paramPath = Bundle.main.path(forResource: "ssd_hand_params", ofType: nil) ?! "para null"
modelDir = ""
preprocessKernel = MobilenetssdPreProccess.init(device: MetalHelper.shared.device)
// fatalError()
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
import UIKit
import MetalKit
import Foundation
import paddle_mobile
import MetalPerformanceShaders
class ScaleKernel: CusomKernel {
init(device: MTLDevice, shape: Shape) {
super.init(device: device, inFunctionName: "scale", outputDim: shape, usePaddleMobileLib: false)
}
}
protocol Net {
var program: Program? { get set }
var executor: Executor<Float32>? { get set }
var except: Int { get }
var dim: (n: Int, h: Int, w: Int, c: Int) { get }
var modelPath: String { get }
var paramPath: String { get }
var modelDir: String { get }
var preprocessKernel: CusomKernel { get }
func getTexture(image: CGImage, getTexture: @escaping (MTLTexture) -> Void)
func resultStr(res: [Float]) -> String
func fetchResult(paddleMobileRes: ResultHolder) -> [Float32]
mutating func load() throws
func predict(inTexture: MTLTexture, completion: @escaping ((time:TimeInterval, resultArray: [Float32])) -> Void) throws
mutating func clear()
}
extension Net {
mutating func load() throws {
let queue = MetalHelper.shared.queue
let loader = Loader<Float32>.init()
do {
program = try loader.load(device: MetalHelper.shared.device, modelPath: modelPath, paraPath: paramPath)
executor = try Executor<Float32>.init(inDevice: MetalHelper.shared.device, inQueue: queue, inProgram: program!)
} catch let error {
throw error
}
}
func predict(inTexture: MTLTexture, completion: @escaping ((time:TimeInterval, resultArray: [Float32])) -> Void) throws {
guard let inExecutor = executor else {
fatalError(" 请先 load ")
}
try inExecutor.predict(input: inTexture, dim: [dim.n, dim.h, dim.w, dim.c], completionHandle: { (result) in
var resultArr:[Float32] = []
resultArr = self.fetchResult(paddleMobileRes: result)
completion((time: TimeInterval(result.elapsedTime), resultArray: resultArr))
}, preProcessKernle: preprocessKernel, except: except)
}
mutating func clear() {
executor?.clear()
program = nil
executor = nil
}
func getTexture(image: CGImage, getTexture: @escaping (MTLTexture) -> Void) {
let texture = try? MetalHelper.shared.textureLoader.newTexture(cgImage: image, options: [:]) ?! " texture loader error"
MetalHelper.scaleTexture(queue: MetalHelper.shared.queue, input: texture!, size: (dim.w, dim.h)) { (resTexture) in
getTexture(resTexture)
}
}
func fetchResult(paddleMobileRes: ResultHolder) -> [Float32] {
return paddleMobileRes.resultArr
}
}
......@@ -25,7 +25,7 @@ let modelHelperMap: [SupportModel : Net] = [.mobilenet_ssd : MobileNet_ssd_hand.
//let modelHelperMap: [SupportModel : Net] = [.mobilenet : MobileNet.init(), .mobilenet_ssd : MobileNet_ssd_hand.init()]
enum SupportModel: String{
// case mobilenet = "mobilenet"
// case mobilenet = "mobilenet"
case mobilenet_ssd = "mobilenetssd"
case genet = "genet"
static func supportedModels() -> [SupportModel] {
......@@ -40,6 +40,7 @@ class ViewController: UIViewController {
@IBOutlet weak var elapsedTimeLabel: UILabel!
@IBOutlet weak var modelPickerView: UIPickerView!
@IBOutlet weak var threadPickerView: UIPickerView!
var runnner: Runner!
var selectImage: UIImage?
var modelType: SupportModel = SupportModel.supportedModels()[0]
var toPredictTexture: MTLTexture?
......@@ -56,10 +57,10 @@ class ViewController: UIViewController {
@IBAction func loadAct(_ sender: Any) {
do {
try self.net.load()
} catch let error {
print(error)
if runnner.load() {
print(" load success ! ")
} else {
print(" load error ! ")
}
}
......@@ -71,7 +72,7 @@ class ViewController: UIViewController {
}
@IBAction func clearAct(_ sender: Any) {
net.clear()
runnner.clear()
}
@IBAction func predictAct(_ sender: Any) {
......@@ -79,26 +80,22 @@ class ViewController: UIViewController {
resultTextView.text = "请选择图片 ! "
return
}
do {
let max = 50
let startDate = Date.init()
for i in 0..<max {
try net.predict(inTexture: inTexture) { [weak self] (result) in
guard let sSelf = self else {
fatalError()
}
if i == max - 1 {
let time = Date.init().timeIntervalSince(startDate)
DispatchQueue.main.async {
sSelf.resultTextView.text = sSelf.net.resultStr(res: result.resultArray)
sSelf.elapsedTimeLabel.text = "平均耗时: \(time/Double(max) * 1000.0) ms"
}
let max = 50
let startDate = Date.init()
for i in 0..<max {
runnner.predict(texture: inTexture) { [weak self] (success, time, result) in
guard let sSelf = self else {
fatalError()
}
if i == max - 1 {
let time = Date.init().timeIntervalSince(startDate)
DispatchQueue.main.async {
sSelf.resultTextView.text = sSelf.net.resultStr(res: result)
sSelf.elapsedTimeLabel.text = "平均耗时: \(time/Double(max) * 1000.0) ms"
}
}
}
} catch let error {
print(error)
}
}
......@@ -111,7 +108,9 @@ class ViewController: UIViewController {
selectImage = UIImage.init(named: "hand.jpg")
selectImageView.image = selectImage
net.getTexture(image: selectImage!.cgImage!) {[weak self] (texture) in
runnner = Runner.init(inNet: net, commandQueue: MetalHelper.shared.queue, inPlatform: .GPU)
runnner.getTexture(image: selectImage!.cgImage!) {[weak self] (texture) in
self?.toPredictTexture = texture
}
}
......@@ -167,7 +166,7 @@ extension ViewController: UIImagePickerControllerDelegate, UINavigationControll
}
sSelf.selectImage = image
sSelf.selectImageView.image = image
sSelf.net.getTexture(image: image.cgImage!, getTexture: { (texture) in
sSelf.runnner.getTexture(image: image.cgImage!, getTexture: { (texture) in
sSelf.toPredictTexture = texture
})
}
......
......@@ -44,6 +44,7 @@
FC3602CC2108819F00FACB58 /* PaddleMobileUnitTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3602CB2108819F00FACB58 /* PaddleMobileUnitTest.swift */; };
FC4CB74920F0B954007C0C6D /* ConvKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC4CB74820F0B954007C0C6D /* ConvKernel.metal */; };
FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */; };
FC4FD9752140E1DE0073E130 /* PaddleMobile.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC4FD9742140E1DE0073E130 /* PaddleMobile.swift */; };
FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */; };
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC60DB8820E9AAA500FF203F /* MetalExtension.swift */; };
FC82735920E3C04200BE430A /* OpCreator.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC82735820E3C04200BE430A /* OpCreator.swift */; };
......@@ -136,6 +137,7 @@
FC3602CB2108819F00FACB58 /* PaddleMobileUnitTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PaddleMobileUnitTest.swift; sourceTree = "<group>"; };
FC4CB74820F0B954007C0C6D /* ConvKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvKernel.metal; sourceTree = "<group>"; };
FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ProgramOptimize.swift; sourceTree = "<group>"; };
FC4FD9742140E1DE0073E130 /* PaddleMobile.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PaddleMobile.swift; sourceTree = "<group>"; };
FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture2DTo2DArrayKernel.swift; sourceTree = "<group>"; };
FC60DB8820E9AAA500FF203F /* MetalExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalExtension.swift; sourceTree = "<group>"; };
FC82735820E3C04200BE430A /* OpCreator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpCreator.swift; sourceTree = "<group>"; };
......@@ -235,10 +237,9 @@
FC039B6C20E11C3C0081E9F8 /* paddle-mobile */ = {
isa = PBXGroup;
children = (
FC4FD9742140E1DE0073E130 /* PaddleMobile.swift */,
FC039BAE20E11CC20081E9F8 /* Program */,
FC039BA320E11CBC0081E9F8 /* Operators */,
FC039BA120E11CB70081E9F8 /* Loader.swift */,
FC039B9A20E11CA00081E9F8 /* Executor.swift */,
FC039B9C20E11CB20081E9F8 /* framework */,
FC039B9320E11C9A0081E9F8 /* Common */,
FC039B6D20E11C3C0081E9F8 /* paddle_mobile.h */,
......@@ -263,6 +264,8 @@
FC039B9C20E11CB20081E9F8 /* framework */ = {
isa = PBXGroup;
children = (
FC039BA120E11CB70081E9F8 /* Loader.swift */,
FC039B9A20E11CA00081E9F8 /* Executor.swift */,
FC039B9D20E11CB20081E9F8 /* Tensor.swift */,
FC039B9E20E11CB20081E9F8 /* Dim.swift */,
FC9D038320E23B01000F735A /* Texture.swift */,
......@@ -524,6 +527,7 @@
FC039B9920E11C9A0081E9F8 /* Types.swift in Sources */,
FC4CB74920F0B954007C0C6D /* ConvKernel.metal in Sources */,
FCA3A1632132A4AC00084FE5 /* ReshapeKernel.metal in Sources */,
FC4FD9752140E1DE0073E130 /* PaddleMobile.swift in Sources */,
FCBCCC592122F42700D94F7E /* ConvBNReluOp.swift in Sources */,
FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */,
FC9D038420E23B01000F735A /* Texture.swift in Sources */,
......
//
// PaddleMobile.swift
// paddle-mobile-demo
//
// Created by liuRuiLong on 2018/9/5.
// Copyright © 2018年 orange. All rights reserved.
//
import Metal
import MetalKit
import Foundation
public enum Platform{
case CPU, GPU
}
class ScaleKernel: CusomKernel {
init(device: MTLDevice, shape: Shape) {
super.init(device: device, inFunctionName: "scale", outputDim: shape, usePaddleMobileLib: false)
}
}
public protocol Net {
var except: Int { get }
var dim: (n: Int, h: Int, w: Int, c: Int) { get }
var preprocessKernel: CusomKernel { get }
// var paramPointer: UnsafeMutableRawPointer { get }
// var paramSize: Int { get }
// var modelPointer: UnsafeMutableRawPointer { get }
// var modelSize: Int { get }
var modelPath: String { get }
var paramPath: String { get }
var modelDir: String { get }
func resultStr(res: [Float]) -> String
func fetchResult(paddleMobileRes: ResultHolder) -> [Float32]
}
extension Net {
func fetchResult(paddleMobileRes: ResultHolder) -> [Float32] {
return paddleMobileRes.resultArr
}
}
public class Runner {
var program: Program?
var executor: Executor<Float32>?
var queue: MTLCommandQueue?
var textureLoader: MTKTextureLoader?
let net: Net
let device: MTLDevice?
let platform: Platform
/**
* inNet: 需要运行的网络
* commandQueue: GPU 是需要传入
* inPlatform: 需要使用的平台, GPU or CPU
*/
public init(inNet: Net, commandQueue: MTLCommandQueue?, inPlatform: Platform) {
net = inNet
queue = commandQueue
device = queue?.device
platform = inPlatform
if let inDevice = device {
textureLoader = MTKTextureLoader.init(device: inDevice)
}
}
/**
* load 模型, 返回 true 可进行预测
*/
public func load() -> Bool {
if platform == .GPU {
guard let inDevice = device, let inQueue = queue else {
print(" paddle mobile gpu load error, need MTLCommandQueue")
return false
}
let loader = Loader<Float32>.init()
do {
program = try loader.load(device: inDevice, modelPath: net.modelPath, paraPath: net.paramPath)
executor = try Executor<Float32>.init(inDevice: inDevice, inQueue: inQueue, inProgram: program!)
} catch let error {
print(error)
return false
}
} else {
print(" need implementation ")
return false
}
return true
}
/**
* CPU GPU 通用版本 predict
* cgImage: 需要预测的图片
* ( _ success: Bool, _ time:TimeInterval, _ resultArray: [Float32]) -> Void : 回调闭包, 三个参数分别为: 是否成功, 预测耗时, 结果数组
*/
public func predict(cgImage: CGImage, completion: @escaping ( _ success: Bool, _ time:TimeInterval, _ resultArray: [Float32]) -> Void) {
if platform == .GPU {
getTexture(image: cgImage) { [weak self] (texture) in
guard let SSelf = self else {
fatalError()
}
SSelf.predict(texture: texture, completion: completion)
}
} else if platform == .CPU {
}
}
/**
* GPU 版本 predict
* texture: 需要预测的 texture 需要做过预处理
* ( _ success: Bool, _ time:TimeInterval, _ resultArray: [Float32]) -> Void : 回调闭包, 三个参数分别为: 是否成功, 预测耗时, 结果数组
*/
public func predict(texture: MTLTexture, completion: @escaping ( _ success: Bool, _ time:TimeInterval, _ resultArray: [Float32]) -> Void) {
do {
try self.executor?.predict(input: texture, dim: [self.net.dim.n, self.net.dim.h, self.net.dim.w, self.net.dim.c], completionHandle: { [weak self] (res) in
guard let SSelf = self else {
fatalError( " self nil " )
}
let resultArray = SSelf.net.fetchResult(paddleMobileRes: res)
completion(true, res.elapsedTime, resultArray)
}, preProcessKernle: self.net.preprocessKernel, except: self.net.except)
} catch let error {
print(error)
completion(false, 0.0, [])
return
}
}
/*
* 清理内存, 调用此函数后, 不能再使用, 需重新 load
*/
public func clear() {
executor?.clear()
executor = nil
program = nil
}
/*
* 获取 texture, 对 texture 进行预处理, GPU 预测时使用
*/
public func getTexture(image: CGImage, getTexture: @escaping (MTLTexture) -> Void) {
let texture = try? textureLoader?.newTexture(cgImage: image, options: [:]) ?! " texture loader error"
scaleTexture(input: texture!, size: (net.dim.w, net.dim.h), complete: getTexture)
}
func scaleTexture(input: MTLTexture, size:(width: Int, height: Int), complete: @escaping (MTLTexture) -> Void) {
guard let inQueue = queue, let inDevice = device else {
fatalError( " queue or devcie nil " )
}
guard let buffer = inQueue.makeCommandBuffer() else {
fatalError( " make buffer error" )
}
let scaleKernel = ScaleKernel.init(device: inDevice, shape: CusomKernel.Shape.init(inWidth: size.width, inHeight: size.height, inChannel: 3))
do {
try scaleKernel.compute(inputTexuture: input, commandBuffer: buffer)
} catch let error {
print(error)
fatalError()
}
buffer.addCompletedHandler { (buffer) in
complete(scaleKernel.outputTexture)
}
buffer.commit()
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册