提交 ee7ed16f 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #913 from codeWorm2015/metal

add new interface
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8520E11C550081E9F8 /* Main.storyboard */; }; FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8520E11C550081E9F8 /* Main.storyboard */; };
FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8820E11C560081E9F8 /* Assets.xcassets */; }; FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8820E11C560081E9F8 /* Assets.xcassets */; };
FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8A20E11C560081E9F8 /* LaunchScreen.storyboard */; }; FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8A20E11C560081E9F8 /* LaunchScreen.storyboard */; };
FC27991021341CE5000B6BAD /* Net.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC27990F21341CE5000B6BAD /* Net.swift */; };
FC27991321343A3A000B6BAD /* CPUCompute.mm in Sources */ = {isa = PBXBuildFile; fileRef = FC27991221343A3A000B6BAD /* CPUCompute.mm */; }; FC27991321343A3A000B6BAD /* CPUCompute.mm in Sources */ = {isa = PBXBuildFile; fileRef = FC27991221343A3A000B6BAD /* CPUCompute.mm */; };
FC3C800F2133F46600D1295E /* MobileNetSSD.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3C800E2133F46600D1295E /* MobileNetSSD.swift */; }; FC3C800F2133F46600D1295E /* MobileNetSSD.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3C800E2133F46600D1295E /* MobileNetSSD.swift */; };
FC3C80112133F4AB00D1295E /* MobileNet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3C80102133F4AB00D1295E /* MobileNet.swift */; }; FC3C80112133F4AB00D1295E /* MobileNet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3C80102133F4AB00D1295E /* MobileNet.swift */; };
FC4FD95121402B610073E130 /* PaddleMobile.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC4FD95021402B610073E130 /* PaddleMobile.swift */; };
FC8CFEE2213524EA0094D569 /* Genet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC8CFEE1213524EA0094D569 /* Genet.swift */; }; FC8CFEE2213524EA0094D569 /* Genet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC8CFEE1213524EA0094D569 /* Genet.swift */; };
FC8CFEE62135452C0094D569 /* genet_params in Resources */ = {isa = PBXBuildFile; fileRef = FC8CFEE42135452B0094D569 /* genet_params */; }; FC8CFEE62135452C0094D569 /* genet_params in Resources */ = {isa = PBXBuildFile; fileRef = FC8CFEE42135452B0094D569 /* genet_params */; };
FC8CFEE72135452C0094D569 /* genet_model in Resources */ = {isa = PBXBuildFile; fileRef = FC8CFEE52135452B0094D569 /* genet_model */; }; FC8CFEE72135452C0094D569 /* genet_model in Resources */ = {isa = PBXBuildFile; fileRef = FC8CFEE52135452B0094D569 /* genet_model */; };
...@@ -61,12 +61,12 @@ ...@@ -61,12 +61,12 @@
FC039B8820E11C560081E9F8 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; }; FC039B8820E11C560081E9F8 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
FC039B8B20E11C560081E9F8 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/LaunchScreen.storyboard; sourceTree = "<group>"; }; FC039B8B20E11C560081E9F8 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/LaunchScreen.storyboard; sourceTree = "<group>"; };
FC039B8D20E11C560081E9F8 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; }; FC039B8D20E11C560081E9F8 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
FC27990F21341CE5000B6BAD /* Net.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Net.swift; sourceTree = "<group>"; };
FC27991121343A39000B6BAD /* paddle-mobile-demo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "paddle-mobile-demo-Bridging-Header.h"; sourceTree = "<group>"; }; FC27991121343A39000B6BAD /* paddle-mobile-demo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "paddle-mobile-demo-Bridging-Header.h"; sourceTree = "<group>"; };
FC27991221343A3A000B6BAD /* CPUCompute.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = CPUCompute.mm; sourceTree = "<group>"; }; FC27991221343A3A000B6BAD /* CPUCompute.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = CPUCompute.mm; sourceTree = "<group>"; };
FC27991421343A46000B6BAD /* CPUCompute.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = CPUCompute.h; sourceTree = "<group>"; }; FC27991421343A46000B6BAD /* CPUCompute.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = CPUCompute.h; sourceTree = "<group>"; };
FC3C800E2133F46600D1295E /* MobileNetSSD.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MobileNetSSD.swift; sourceTree = "<group>"; }; FC3C800E2133F46600D1295E /* MobileNetSSD.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MobileNetSSD.swift; sourceTree = "<group>"; };
FC3C80102133F4AB00D1295E /* MobileNet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MobileNet.swift; sourceTree = "<group>"; }; FC3C80102133F4AB00D1295E /* MobileNet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MobileNet.swift; sourceTree = "<group>"; };
FC4FD95021402B610073E130 /* PaddleMobile.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; name = PaddleMobile.swift; path = ../../../../../../Desktop/PaddleMobile.swift; sourceTree = "<group>"; };
FC8CFEE1213524EA0094D569 /* Genet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Genet.swift; sourceTree = "<group>"; }; FC8CFEE1213524EA0094D569 /* Genet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Genet.swift; sourceTree = "<group>"; };
FC8CFEE42135452B0094D569 /* genet_params */ = {isa = PBXFileReference; lastKnownFileType = file; path = genet_params; sourceTree = "<group>"; }; FC8CFEE42135452B0094D569 /* genet_params */ = {isa = PBXFileReference; lastKnownFileType = file; path = genet_params; sourceTree = "<group>"; };
FC8CFEE52135452B0094D569 /* genet_model */ = {isa = PBXFileReference; lastKnownFileType = file; path = genet_model; sourceTree = "<group>"; }; FC8CFEE52135452B0094D569 /* genet_model */ = {isa = PBXFileReference; lastKnownFileType = file; path = genet_model; sourceTree = "<group>"; };
...@@ -176,11 +176,11 @@ ...@@ -176,11 +176,11 @@
FC8CFED2213519540094D569 /* Net */ = { FC8CFED2213519540094D569 /* Net */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC4FD95021402B610073E130 /* PaddleMobile.swift */,
FC013927210204A3008100E3 /* PreProcessKernel.metal */, FC013927210204A3008100E3 /* PreProcessKernel.metal */,
FCBCCC542122EF5400D94F7E /* MetalHelper.swift */, FCBCCC542122EF5400D94F7E /* MetalHelper.swift */,
FC3C800E2133F46600D1295E /* MobileNetSSD.swift */, FC3C800E2133F46600D1295E /* MobileNetSSD.swift */,
FC3C80102133F4AB00D1295E /* MobileNet.swift */, FC3C80102133F4AB00D1295E /* MobileNet.swift */,
FC27990F21341CE5000B6BAD /* Net.swift */,
FC27991221343A3A000B6BAD /* CPUCompute.mm */, FC27991221343A3A000B6BAD /* CPUCompute.mm */,
FC27991421343A46000B6BAD /* CPUCompute.h */, FC27991421343A46000B6BAD /* CPUCompute.h */,
FC8CFEE1213524EA0094D569 /* Genet.swift */, FC8CFEE1213524EA0094D569 /* Genet.swift */,
...@@ -342,11 +342,11 @@ ...@@ -342,11 +342,11 @@
files = ( files = (
FC039B8420E11C550081E9F8 /* ViewController.swift in Sources */, FC039B8420E11C550081E9F8 /* ViewController.swift in Sources */,
FC013928210204A3008100E3 /* PreProcessKernel.metal in Sources */, FC013928210204A3008100E3 /* PreProcessKernel.metal in Sources */,
FC27991021341CE5000B6BAD /* Net.swift in Sources */,
FC8CFEE2213524EA0094D569 /* Genet.swift in Sources */, FC8CFEE2213524EA0094D569 /* Genet.swift in Sources */,
FCBCCC552122EF5500D94F7E /* MetalHelper.swift in Sources */, FCBCCC552122EF5500D94F7E /* MetalHelper.swift in Sources */,
FC27991321343A3A000B6BAD /* CPUCompute.mm in Sources */, FC27991321343A3A000B6BAD /* CPUCompute.mm in Sources */,
FC3C80112133F4AB00D1295E /* MobileNet.swift in Sources */, FC3C80112133F4AB00D1295E /* MobileNet.swift in Sources */,
FC4FD95121402B610073E130 /* PaddleMobile.swift in Sources */,
FC3C800F2133F46600D1295E /* MobileNetSSD.swift in Sources */, FC3C800F2133F46600D1295E /* MobileNetSSD.swift in Sources */,
FC039B8220E11C550081E9F8 /* AppDelegate.swift in Sources */, FC039B8220E11C550081E9F8 /* AppDelegate.swift in Sources */,
); );
...@@ -494,19 +494,19 @@ ...@@ -494,19 +494,19 @@
buildSettings = { buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
CLANG_ENABLE_MODULES = YES; CLANG_ENABLE_MODULES = YES;
CODE_SIGN_IDENTITY = "iPhone Distribution"; CODE_SIGN_IDENTITY = "iPhone Developer";
CODE_SIGN_STYLE = Manual; CODE_SIGN_STYLE = Automatic;
DEVELOPMENT_TEAM = 6T9LLJKSM4; DEVELOPMENT_TEAM = A798K58VVL;
INFOPLIST_FILE = "paddle-mobile-demo/Info.plist"; INFOPLIST_FILE = "paddle-mobile-demo/Info.plist";
IPHONEOS_DEPLOYMENT_TARGET = 9.0; IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = ( LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)", "$(inherited)",
"@executable_path/Frameworks", "@executable_path/Frameworks",
); );
PRODUCT_BUNDLE_IDENTIFIER = com.baidu.mms.qa; PRODUCT_BUNDLE_IDENTIFIER = "com.baidu.paddle-mobile";
PRODUCT_NAME = "$(TARGET_NAME)"; PRODUCT_NAME = "$(TARGET_NAME)";
PROVISIONING_PROFILE = "ba9c4b24-7bd0-49c5-93cd-e3051e775d6c"; PROVISIONING_PROFILE = "";
PROVISIONING_PROFILE_SPECIFIER = Distribution_MMS; PROVISIONING_PROFILE_SPECIFIER = "";
SWIFT_OBJC_BRIDGING_HEADER = "paddle-mobile-demo/paddle-mobile-demo-Bridging-Header.h"; SWIFT_OBJC_BRIDGING_HEADER = "paddle-mobile-demo/paddle-mobile-demo-Bridging-Header.h";
SWIFT_OPTIMIZATION_LEVEL = "-Onone"; SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 4.0; SWIFT_VERSION = 4.0;
...@@ -520,19 +520,19 @@ ...@@ -520,19 +520,19 @@
buildSettings = { buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
CLANG_ENABLE_MODULES = YES; CLANG_ENABLE_MODULES = YES;
CODE_SIGN_IDENTITY = "iPhone Distribution"; CODE_SIGN_IDENTITY = "iPhone Developer";
CODE_SIGN_STYLE = Manual; CODE_SIGN_STYLE = Automatic;
DEVELOPMENT_TEAM = 6T9LLJKSM4; DEVELOPMENT_TEAM = A798K58VVL;
INFOPLIST_FILE = "paddle-mobile-demo/Info.plist"; INFOPLIST_FILE = "paddle-mobile-demo/Info.plist";
IPHONEOS_DEPLOYMENT_TARGET = 9.0; IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = ( LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)", "$(inherited)",
"@executable_path/Frameworks", "@executable_path/Frameworks",
); );
PRODUCT_BUNDLE_IDENTIFIER = com.baidu.mms.qa; PRODUCT_BUNDLE_IDENTIFIER = "com.baidu.paddle-mobile";
PRODUCT_NAME = "$(TARGET_NAME)"; PRODUCT_NAME = "$(TARGET_NAME)";
PROVISIONING_PROFILE = "ba9c4b24-7bd0-49c5-93cd-e3051e775d6c"; PROVISIONING_PROFILE = "";
PROVISIONING_PROFILE_SPECIFIER = Distribution_MMS; PROVISIONING_PROFILE_SPECIFIER = "";
SWIFT_OBJC_BRIDGING_HEADER = "paddle-mobile-demo/paddle-mobile-demo-Bridging-Header.h"; SWIFT_OBJC_BRIDGING_HEADER = "paddle-mobile-demo/paddle-mobile-demo-Bridging-Header.h";
SWIFT_VERSION = 4.0; SWIFT_VERSION = 4.0;
TARGETED_DEVICE_FAMILY = "1,2"; TARGETED_DEVICE_FAMILY = "1,2";
......
...@@ -17,10 +17,6 @@ import paddle_mobile ...@@ -17,10 +17,6 @@ import paddle_mobile
class Genet: Net { class Genet: Net {
var program: Program?
var executor: Executor<Float32>?
let except: Int = 0 let except: Int = 0
class GenetPreProccess: CusomKernel { class GenetPreProccess: CusomKernel {
......
...@@ -28,25 +28,6 @@ class MetalHelper { ...@@ -28,25 +28,6 @@ class MetalHelper {
textureLoader = MTKTextureLoader.init(device: device) textureLoader = MTKTextureLoader.init(device: device)
} }
static func scaleTexture(queue: MTLCommandQueue, input: MTLTexture, size:(width: Int, height: Int), complete: @escaping (MTLTexture) -> Void) {
guard let buffer = queue.makeCommandBuffer() else {
fatalError()
}
let scaleKernel = ScaleKernel.init(device: MetalHelper.shared.device, shape: CusomKernel.Shape.init(inWidth: size.width, inHeight: size.height, inChannel: 3))
do {
try scaleKernel.compute(inputTexuture: input, commandBuffer: buffer)
} catch let error {
print(error)
fatalError()
}
buffer.addCompletedHandler { (buffer) in
complete(scaleKernel.outputTexture)
}
buffer.commit()
}
} }
...@@ -15,12 +15,10 @@ ...@@ -15,12 +15,10 @@
import Foundation import Foundation
import paddle_mobile import paddle_mobile
class MobileNet: Net{ class MobileNet: Net{
var program: Program?
var executor: Executor<Float32>?
let except: Int = 0 let except: Int = 0
class MobilenetPreProccess: CusomKernel { class MobilenetPreProccess: CusomKernel {
......
...@@ -14,13 +14,8 @@ ...@@ -14,13 +14,8 @@
import Foundation import Foundation
import paddle_mobile import paddle_mobile
//import
class MobileNet_ssd_hand: Net{ class MobileNet_ssd_hand: Net{
var program: Program?
var executor: Executor<Float32>?
let except: Int = 2 let except: Int = 2
class MobilenetssdPreProccess: CusomKernel { class MobilenetssdPreProccess: CusomKernel {
init(device: MTLDevice) { init(device: MTLDevice) {
...@@ -83,36 +78,11 @@ class MobileNet_ssd_hand: Net{ ...@@ -83,36 +78,11 @@ class MobileNet_ssd_hand: Net{
let modelDir: String let modelDir: String
// let paramPointer: UnsafeMutableRawPointer
//
// let paramSize: Int
//
// let modelPointer: UnsafeMutableRawPointer
//
// let modelSize: Int
//
// /**
// * inParamPointer: 参数文件内存地址
// * inParamSize: 参数文件大小(字节数)
// * inModelPointer: 模型文件内存地址
// * inModelSize: 模型文件大小(字节数)
// */
// init(inParamPointer: UnsafeMutableRawPointer, inParamSize: Int, inModelPointer: UnsafeMutableRawPointer, inModelSize: Int) {
// paramPointer = inParamPointer
// paramSize = inParamSize
// modelPointer = inModelPointer
// modelSize = inModelSize
//// fatalError()
// }
init() { init() {
modelPath = Bundle.main.path(forResource: "ssd_hand_model", ofType: nil) ?! "model null" modelPath = Bundle.main.path(forResource: "ssd_hand_model", ofType: nil) ?! "model null"
paramPath = Bundle.main.path(forResource: "ssd_hand_params", ofType: nil) ?! "para null" paramPath = Bundle.main.path(forResource: "ssd_hand_params", ofType: nil) ?! "para null"
modelDir = "" modelDir = ""
preprocessKernel = MobilenetssdPreProccess.init(device: MetalHelper.shared.device) preprocessKernel = MobilenetssdPreProccess.init(device: MetalHelper.shared.device)
// fatalError()
} }
} }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
import UIKit
import MetalKit
import Foundation
import paddle_mobile
import MetalPerformanceShaders
/// Custom kernel that is initialised with the Metal function name "scale".
///
/// The initializer only forwards its arguments to `CusomKernel`, fixing the
/// function name and passing `usePaddleMobileLib: false`.
class ScaleKernel: CusomKernel {
    /// - Parameters:
    ///   - device: Metal device handed on to the superclass initializer.
    ///   - shape: Forwarded to the superclass as `outputDim`.
    init(device: MTLDevice, shape: Shape) {
        super.init(
            device: device,
            inFunctionName: "scale",
            outputDim: shape,
            usePaddleMobileLib: false
        )
    }
}
/// Describes one model that can be loaded and run through the shared Metal helper.
///
/// The protocol extension in this file supplies default implementations of
/// `load()`, `predict(inTexture:completion:)`, `clear()`, `getTexture(image:getTexture:)`
/// and `fetchResult(paddleMobileRes:)`.
protocol Net {
/// Loaded program; assigned by the default `load()` and reset to nil by the default `clear()`.
var program: Program? { get set }
/// Executor built by the default `load()`; the default `predict` traps if it is nil.
var executor: Executor<Float32>? { get set }
/// Forwarded as the `except:` argument of the executor's predict call.
/// NOTE(review): its meaning is defined by `Executor`, which is not visible here.
var except: Int { get }
/// Input dimensions. The default `predict` passes them as `[n, h, w, c]`, and the
/// default `getTexture` scales the input image to `(w, h)`.
var dim: (n: Int, h: Int, w: Int, c: Int) { get }
/// Path passed to the loader as `modelPath`.
var modelPath: String { get }
/// Path passed to the loader as `paraPath`.
var paramPath: String { get }
/// Not read by any of the default implementations in this file.
var modelDir: String { get }
/// Kernel handed to the executor as its pre-processing step.
var preprocessKernel: CusomKernel { get }
/// Produces a texture for `image` and delivers it through the `getTexture` callback.
func getTexture(image: CGImage, getTexture: @escaping (MTLTexture) -> Void)
/// Formats a result array as text. No default implementation is provided by the extension.
func resultStr(res: [Float]) -> String
/// Extracts the float results from a `ResultHolder`; the default returns `resultArr`.
func fetchResult(paddleMobileRes: ResultHolder) -> [Float32]
/// Loads the model and prepares the executor. Throws whatever the loader or executor throws.
mutating func load() throws
/// Runs a prediction on `inTexture` and reports elapsed time plus results through `completion`.
func predict(inTexture: MTLTexture, completion: @escaping ((time:TimeInterval, resultArray: [Float32])) -> Void) throws
/// Releases the executor and the program.
mutating func clear()
}
/// Default implementations shared by every `Net` conformer.
extension Net {
    /// Loads the program from `modelPath` / `paramPath` and builds an executor for it
    /// on the shared Metal device and queue.
    ///
    /// - Throws: Any error thrown by the loader or by the executor initializer,
    ///   propagated unchanged. If the executor initializer throws, `program` stays
    ///   assigned and `executor` is left untouched.
    mutating func load() throws {
        let queue = MetalHelper.shared.queue
        let loader = Loader<Float32>.init()
        program = try loader.load(device: MetalHelper.shared.device, modelPath: modelPath, paraPath: paramPath)
        // `program` was assigned on the previous line, so this unwrap reads that value.
        executor = try Executor<Float32>.init(inDevice: MetalHelper.shared.device, inQueue: queue, inProgram: program!)
    }

    /// Runs one prediction on `inTexture`.
    ///
    /// Traps if `load()` has not produced an executor yet.
    ///
    /// - Parameters:
    ///   - inTexture: Input texture handed to the executor.
    ///   - completion: Receives the executor-reported elapsed time and the values
    ///     returned by `fetchResult(paddleMobileRes:)`.
    /// - Throws: Any error thrown by the executor's predict call.
    func predict(inTexture: MTLTexture, completion: @escaping ((time:TimeInterval, resultArray: [Float32])) -> Void) throws {
        guard let inExecutor = executor else {
            fatalError(" 请先 load ")
        }
        try inExecutor.predict(input: inTexture, dim: [dim.n, dim.h, dim.w, dim.c], completionHandle: { (result) in
            let resultArr: [Float32] = self.fetchResult(paddleMobileRes: result)
            completion((time: TimeInterval(result.elapsedTime), resultArray: resultArr))
        }, preProcessKernle: preprocessKernel, except: except)
    }

    /// Clears the executor, then drops both the program and the executor.
    mutating func clear() {
        executor?.clear()
        program = nil
        executor = nil
    }

    /// Loads `image` into a Metal texture, scales it to `(dim.w, dim.h)` and passes
    /// the scaled texture to `getTexture`.
    ///
    /// Traps with " texture loader error" when the texture loader fails.
    func getTexture(image: CGImage, getTexture: @escaping (MTLTexture) -> Void) {
        // `try?` applies to everything on its right, so the failure message has to be
        // attached after the optional is produced; otherwise a thrown loader error
        // would only show up as a bare force-unwrap crash.
        guard let texture = try? MetalHelper.shared.textureLoader.newTexture(cgImage: image, options: [:]) else {
            fatalError(" texture loader error")
        }
        MetalHelper.scaleTexture(queue: MetalHelper.shared.queue, input: texture, size: (dim.w, dim.h)) { (resTexture) in
            getTexture(resTexture)
        }
    }

    /// Returns the raw result array carried by `paddleMobileRes`.
    func fetchResult(paddleMobileRes: ResultHolder) -> [Float32] {
        return paddleMobileRes.resultArr
    }
}
//
// PaddleMobile.swift
// paddle-mobile-demo
//
// Created by liuRuiLong on 2018/9/5.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
...@@ -25,7 +25,7 @@ let modelHelperMap: [SupportModel : Net] = [.mobilenet_ssd : MobileNet_ssd_hand. ...@@ -25,7 +25,7 @@ let modelHelperMap: [SupportModel : Net] = [.mobilenet_ssd : MobileNet_ssd_hand.
//let modelHelperMap: [SupportModel : Net] = [.mobilenet : MobileNet.init(), .mobilenet_ssd : MobileNet_ssd_hand.init()] //let modelHelperMap: [SupportModel : Net] = [.mobilenet : MobileNet.init(), .mobilenet_ssd : MobileNet_ssd_hand.init()]
enum SupportModel: String{ enum SupportModel: String{
// case mobilenet = "mobilenet" // case mobilenet = "mobilenet"
case mobilenet_ssd = "mobilenetssd" case mobilenet_ssd = "mobilenetssd"
case genet = "genet" case genet = "genet"
static func supportedModels() -> [SupportModel] { static func supportedModels() -> [SupportModel] {
...@@ -40,6 +40,7 @@ class ViewController: UIViewController { ...@@ -40,6 +40,7 @@ class ViewController: UIViewController {
@IBOutlet weak var elapsedTimeLabel: UILabel! @IBOutlet weak var elapsedTimeLabel: UILabel!
@IBOutlet weak var modelPickerView: UIPickerView! @IBOutlet weak var modelPickerView: UIPickerView!
@IBOutlet weak var threadPickerView: UIPickerView! @IBOutlet weak var threadPickerView: UIPickerView!
var runnner: Runner!
var selectImage: UIImage? var selectImage: UIImage?
var modelType: SupportModel = SupportModel.supportedModels()[0] var modelType: SupportModel = SupportModel.supportedModels()[0]
var toPredictTexture: MTLTexture? var toPredictTexture: MTLTexture?
...@@ -56,10 +57,10 @@ class ViewController: UIViewController { ...@@ -56,10 +57,10 @@ class ViewController: UIViewController {
@IBAction func loadAct(_ sender: Any) { @IBAction func loadAct(_ sender: Any) {
do { if runnner.load() {
try self.net.load() print(" load success ! ")
} catch let error { } else {
print(error) print(" load error ! ")
} }
} }
...@@ -71,7 +72,7 @@ class ViewController: UIViewController { ...@@ -71,7 +72,7 @@ class ViewController: UIViewController {
} }
@IBAction func clearAct(_ sender: Any) { @IBAction func clearAct(_ sender: Any) {
net.clear() runnner.clear()
} }
@IBAction func predictAct(_ sender: Any) { @IBAction func predictAct(_ sender: Any) {
...@@ -79,26 +80,22 @@ class ViewController: UIViewController { ...@@ -79,26 +80,22 @@ class ViewController: UIViewController {
resultTextView.text = "请选择图片 ! " resultTextView.text = "请选择图片 ! "
return return
} }
do { let max = 50
let max = 50 let startDate = Date.init()
let startDate = Date.init() for i in 0..<max {
for i in 0..<max { runnner.predict(texture: inTexture) { [weak self] (success, time, result) in
try net.predict(inTexture: inTexture) { [weak self] (result) in guard let sSelf = self else {
guard let sSelf = self else { fatalError()
fatalError() }
}
if i == max - 1 {
if i == max - 1 { let time = Date.init().timeIntervalSince(startDate)
let time = Date.init().timeIntervalSince(startDate) DispatchQueue.main.async {
DispatchQueue.main.async { sSelf.resultTextView.text = sSelf.net.resultStr(res: result)
sSelf.resultTextView.text = sSelf.net.resultStr(res: result.resultArray) sSelf.elapsedTimeLabel.text = "平均耗时: \(time/Double(max) * 1000.0) ms"
sSelf.elapsedTimeLabel.text = "平均耗时: \(time/Double(max) * 1000.0) ms"
}
} }
} }
} }
} catch let error {
print(error)
} }
} }
...@@ -111,7 +108,9 @@ class ViewController: UIViewController { ...@@ -111,7 +108,9 @@ class ViewController: UIViewController {
selectImage = UIImage.init(named: "hand.jpg") selectImage = UIImage.init(named: "hand.jpg")
selectImageView.image = selectImage selectImageView.image = selectImage
net.getTexture(image: selectImage!.cgImage!) {[weak self] (texture) in
runnner = Runner.init(inNet: net, commandQueue: MetalHelper.shared.queue, inPlatform: .GPU)
runnner.getTexture(image: selectImage!.cgImage!) {[weak self] (texture) in
self?.toPredictTexture = texture self?.toPredictTexture = texture
} }
} }
...@@ -167,7 +166,7 @@ extension ViewController: UIImagePickerControllerDelegate, UINavigationControll ...@@ -167,7 +166,7 @@ extension ViewController: UIImagePickerControllerDelegate, UINavigationControll
} }
sSelf.selectImage = image sSelf.selectImage = image
sSelf.selectImageView.image = image sSelf.selectImageView.image = image
sSelf.net.getTexture(image: image.cgImage!, getTexture: { (texture) in sSelf.runnner.getTexture(image: image.cgImage!, getTexture: { (texture) in
sSelf.toPredictTexture = texture sSelf.toPredictTexture = texture
}) })
} }
......
...@@ -44,6 +44,7 @@ ...@@ -44,6 +44,7 @@
FC3602CC2108819F00FACB58 /* PaddleMobileUnitTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3602CB2108819F00FACB58 /* PaddleMobileUnitTest.swift */; }; FC3602CC2108819F00FACB58 /* PaddleMobileUnitTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3602CB2108819F00FACB58 /* PaddleMobileUnitTest.swift */; };
FC4CB74920F0B954007C0C6D /* ConvKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC4CB74820F0B954007C0C6D /* ConvKernel.metal */; }; FC4CB74920F0B954007C0C6D /* ConvKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC4CB74820F0B954007C0C6D /* ConvKernel.metal */; };
FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */; }; FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */; };
FC4FD9752140E1DE0073E130 /* PaddleMobile.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC4FD9742140E1DE0073E130 /* PaddleMobile.swift */; };
FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */; }; FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */; };
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC60DB8820E9AAA500FF203F /* MetalExtension.swift */; }; FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC60DB8820E9AAA500FF203F /* MetalExtension.swift */; };
FC82735920E3C04200BE430A /* OpCreator.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC82735820E3C04200BE430A /* OpCreator.swift */; }; FC82735920E3C04200BE430A /* OpCreator.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC82735820E3C04200BE430A /* OpCreator.swift */; };
...@@ -136,6 +137,7 @@ ...@@ -136,6 +137,7 @@
FC3602CB2108819F00FACB58 /* PaddleMobileUnitTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PaddleMobileUnitTest.swift; sourceTree = "<group>"; }; FC3602CB2108819F00FACB58 /* PaddleMobileUnitTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PaddleMobileUnitTest.swift; sourceTree = "<group>"; };
FC4CB74820F0B954007C0C6D /* ConvKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvKernel.metal; sourceTree = "<group>"; }; FC4CB74820F0B954007C0C6D /* ConvKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvKernel.metal; sourceTree = "<group>"; };
FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ProgramOptimize.swift; sourceTree = "<group>"; }; FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ProgramOptimize.swift; sourceTree = "<group>"; };
FC4FD9742140E1DE0073E130 /* PaddleMobile.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PaddleMobile.swift; sourceTree = "<group>"; };
FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture2DTo2DArrayKernel.swift; sourceTree = "<group>"; }; FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture2DTo2DArrayKernel.swift; sourceTree = "<group>"; };
FC60DB8820E9AAA500FF203F /* MetalExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalExtension.swift; sourceTree = "<group>"; }; FC60DB8820E9AAA500FF203F /* MetalExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalExtension.swift; sourceTree = "<group>"; };
FC82735820E3C04200BE430A /* OpCreator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpCreator.swift; sourceTree = "<group>"; }; FC82735820E3C04200BE430A /* OpCreator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpCreator.swift; sourceTree = "<group>"; };
...@@ -235,10 +237,9 @@ ...@@ -235,10 +237,9 @@
FC039B6C20E11C3C0081E9F8 /* paddle-mobile */ = { FC039B6C20E11C3C0081E9F8 /* paddle-mobile */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC4FD9742140E1DE0073E130 /* PaddleMobile.swift */,
FC039BAE20E11CC20081E9F8 /* Program */, FC039BAE20E11CC20081E9F8 /* Program */,
FC039BA320E11CBC0081E9F8 /* Operators */, FC039BA320E11CBC0081E9F8 /* Operators */,
FC039BA120E11CB70081E9F8 /* Loader.swift */,
FC039B9A20E11CA00081E9F8 /* Executor.swift */,
FC039B9C20E11CB20081E9F8 /* framework */, FC039B9C20E11CB20081E9F8 /* framework */,
FC039B9320E11C9A0081E9F8 /* Common */, FC039B9320E11C9A0081E9F8 /* Common */,
FC039B6D20E11C3C0081E9F8 /* paddle_mobile.h */, FC039B6D20E11C3C0081E9F8 /* paddle_mobile.h */,
...@@ -263,6 +264,8 @@ ...@@ -263,6 +264,8 @@
FC039B9C20E11CB20081E9F8 /* framework */ = { FC039B9C20E11CB20081E9F8 /* framework */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC039BA120E11CB70081E9F8 /* Loader.swift */,
FC039B9A20E11CA00081E9F8 /* Executor.swift */,
FC039B9D20E11CB20081E9F8 /* Tensor.swift */, FC039B9D20E11CB20081E9F8 /* Tensor.swift */,
FC039B9E20E11CB20081E9F8 /* Dim.swift */, FC039B9E20E11CB20081E9F8 /* Dim.swift */,
FC9D038320E23B01000F735A /* Texture.swift */, FC9D038320E23B01000F735A /* Texture.swift */,
...@@ -524,6 +527,7 @@ ...@@ -524,6 +527,7 @@
FC039B9920E11C9A0081E9F8 /* Types.swift in Sources */, FC039B9920E11C9A0081E9F8 /* Types.swift in Sources */,
FC4CB74920F0B954007C0C6D /* ConvKernel.metal in Sources */, FC4CB74920F0B954007C0C6D /* ConvKernel.metal in Sources */,
FCA3A1632132A4AC00084FE5 /* ReshapeKernel.metal in Sources */, FCA3A1632132A4AC00084FE5 /* ReshapeKernel.metal in Sources */,
FC4FD9752140E1DE0073E130 /* PaddleMobile.swift in Sources */,
FCBCCC592122F42700D94F7E /* ConvBNReluOp.swift in Sources */, FCBCCC592122F42700D94F7E /* ConvBNReluOp.swift in Sources */,
FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */, FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */,
FC9D038420E23B01000F735A /* Texture.swift in Sources */, FC9D038420E23B01000F735A /* Texture.swift in Sources */,
......
//
// CNNConvAddBatchNormReluOp.swift
// paddle-mobile
import Foundation
/// `TestParam` carrying the Metal parameters, filter weights and bias values
/// for an MPS convolution test.
///
/// The filter and bias arrays are copied into storage owned by this object, so
/// `filterPointer` and `biasePointer` remain valid for the object's lifetime and
/// are released in `deinit`.
class CNNMPSConvTestParam: TestParam {
    var outputTexture: MTLTexture?
    var metalParam: MetalConvParam
    let filterPointer: UnsafeMutableRawPointer
    let biasePointer: UnsafeMutablePointer<Float>
    let filterSize: (width: Int, height: Int, channel: Int)

    /// - Parameters:
    ///   - inMetalParam: Stored as `metalParam`.
    ///   - inFilter: Filter values; copied, the array itself is not retained.
    ///   - inBiase: Bias values; copied, the array itself is not retained.
    ///   - inFilterSize: Stored as `filterSize`.
    init(inMetalParam: MetalConvParam, inFilter: [Float], inBiase: [Float], inFilterSize: (width: Int, height: Int, channel: Int)) {
        metalParam = inMetalParam
        // A pointer obtained by passing an array to a pointer parameter is only valid
        // during that call; storing it would leave a dangling pointer. Copy instead.
        filterPointer = UnsafeMutableRawPointer(CNNMPSConvTestParam.makeOwnedCopy(of: inFilter))
        biasePointer = CNNMPSConvTestParam.makeOwnedCopy(of: inBiase)
        filterSize = inFilterSize
    }

    deinit {
        filterPointer.deallocate()
        biasePointer.deallocate()
    }

    /// Allocates storage for `values.count` floats and copies `values` into it.
    /// The caller owns the returned memory and must deallocate it.
    private static func makeOwnedCopy(of values: [Float]) -> UnsafeMutablePointer<Float> {
        let storage = UnsafeMutablePointer<Float>.allocate(capacity: values.count)
        storage.initialize(from: values, count: values.count)
        return storage
    }
}
/// Operator built on `CNNConvKernel` / `CNNConvParam` that registers itself as the
/// fusion of a convolution node followed by an element-wise add node.
@available(iOS 10.0, *)
class CNNMPSConvOp<P: PrecisionType>: Operator<CNNConvKernel<P>, CNNConvParam<P>>, Runable, Creator, InferShaperable, Fusion {
    typealias OpType = CNNMPSConvOp<P>

    /// Unconditionally traps via `fatalError()`.
    required init(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws {
        fatalError()
    }

    /// Encodes the kernel into `buffer`; any error thrown by the kernel propagates unchanged.
    func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
        try kernel.compute(commandBuffer: buffer, param: para)
    }

    /// Intentionally empty.
    func delogOutput() {
    }

    /// Pattern this op fuses: a `gConvType` node linked to a `gElementwiseAdd` node.
    static func fusionNode() -> Node {
        let convNode = Node.init(inType: gConvType)
        _ = convNode --> Node.init(inType: gElementwiseAdd)
        return convNode
    }

    /// Returns an empty mapping.
    static func change() -> [String : [(from: String, to: String)]] {
        return [:]
    }

    /// Type name under which the fused op is registered.
    static func fusionType() -> String {
        return gMPSCNNConvType
    }

    /// Computes the output dims from the input dims, filter dims, strides,
    /// paddings and dilations, and stores them on `para.output`.
    ///
    /// The result is `[in[0], spatial..., filter[0]]`, where each spatial entry is
    /// `(in + 2 * padding - (dilation * (filter - 1) + 1)) / stride + 1`.
    func inferShape() {
        let inputDims = para.input.dim
        let kernelDims = para.filter.dim
        let strides = para.stride
        let paddings = para.paddings
        let dilations = para.dilations

        let spatialSizes = (0..<strides.count).map { (axis) -> Int in
            let kernelExtent: Int = kernelDims[axis + 1]
            let inputExtent: Int = inputDims[axis + 1]
            let dilatedKernel = Int(dilations[axis]) * (kernelExtent - 1) + 1
            let padded = inputExtent + 2 * Int(paddings[axis])
            return (padded - dilatedKernel) / Int(strides[axis]) + 1
        }

        var shape = [inputDims[0]]
        shape.append(contentsOf: spatialSizes)
        shape.append(kernelDims[0])
        para.output.dim = Dim.init(inDim: shape)
    }
}
//
// BatchNormRelu.swift
// paddle-mobile
//
// Created by zhangxinjun on 2018/8/23.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
/// Parameter type used by `BatchNormReluKernel`.
/// Declares no members of its own; everything comes from `BatchNormParam`.
class BatchNormReluParam<P: PrecisionType>: BatchNormParam<P> {
}
class BatchNormReluKernel<P: PrecisionType>: Kernel, Computable{
typealias ParamType = BatchNormReluParam<P>
var newScale: MTLBuffer
var newBias: MTLBuffer
/// Builds the kernel from buffers already prepared in `testParam`.
///
/// - Parameters:
///   - device: Metal device forwarded to the superclass initializer.
///   - testParam: Supplies the scale and bias buffers stored on this kernel.
required init(device: MTLDevice, testParam: BatchNormReluTestParam) {
    self.newScale = testParam.newScaleBuffer
    self.newBias = testParam.newBiaseBuffer
    super.init(
        device: device,
        inFunctionName: "batch_norm_relu_3x3"
    )
}
/// Builds the kernel from `param`, precomputing a per-element scale and bias.
///
/// Two new buffers are allocated with the same lengths as the input scale and
/// input bias buffers; allocation failure traps. They are then filled with
///   newScale[i] = invStd[i] * scale[i]
///   newBias[i]  = bias[i] - mean[i] * invStd[i] * scale[i]
/// where invStd[i] = 1 / sqrt(variance[i] + epsilon).
required init(device: MTLDevice, param: BatchNormReluParam<P>) {
guard let newScale = device.makeBuffer(length: param.inputScale.buffer.length) else {
fatalError()
}
guard let newBias = device.makeBuffer(length: param.inputBias.buffer.length) else {
fatalError()
}
self.newScale = newScale
self.newBias = newBias
super.init(device: device, inFunctionName: "batch_norm_relu_3x3")
let varianceBuffer : MTLBuffer = param.inputVariance.buffer
// NOTE(review): the array is sized with `varianceBuffer.length`, while the loop
// below divides that same length by the element stride. The array therefore has
// more entries than are written; the extras stay 0. Confirm whether this is intended.
var invStd: [Float32] = Array(repeating: 0, count: varianceBuffer.length)
let varianceContents = varianceBuffer.contents().assumingMemoryBound(to: P.self)
// Inverse standard deviation per element.
for i in 0..<(varianceBuffer.length / MemoryLayout<P>.stride) {
invStd[i] = 1 / (Float32(varianceContents[i]) + param.epsilon).squareRoot()
}
let newScaleContents = newScale.contents().assumingMemoryBound(to: P.self)
let newBiasContents = newBias.contents().assumingMemoryBound(to: P.self)
let scale : MTLBuffer = param.inputScale.buffer
let scaleContents = scale.contents().assumingMemoryBound(to: P.self)
let bias : MTLBuffer = param.inputBias.buffer
let biasContents = bias.contents().assumingMemoryBound(to: P.self)
let meanContents = param.inputMean.buffer.contents().assumingMemoryBound(to: P.self)
// Iterates over the element count of the new scale buffer.
// NOTE(review): assumes the bias, mean and variance buffers hold at least as many
// elements as the scale buffer — verify against the caller.
for i in 0..<(newScale.length / MemoryLayout<P>.stride) {
newScaleContents[i] = P(invStd[i] * Float32(scaleContents[i]))
newBiasContents[i] = P(Float32(biasContents[i]) - Float32(meanContents[i]) * invStd[i] * Float32(scaleContents[i]))
}
}
func compute(commandBuffer: MTLCommandBuffer, param: BatchNormReluParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
fatalError()
}
encoder.setTexture(param.input as? MTLTexture, index: 0)
encoder.setTexture(param.output as? MTLTexture, index: 1)
encoder.setBuffer(newScale, offset: 0, index: 1)
encoder.setBuffer(newBias, offset: 0, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.output as! MTLTexture)
encoder.endEncoding()
}
func testCompute(commandBuffer: MTLCommandBuffer, testParam: BatchNormReluTestParam) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
fatalError()
}
encoder.setTexture(testParam.inputTexture, index: 0)
encoder.setTexture(testParam.outputTexture, index: 1)
encoder.setBuffer(newScale, offset: 0, index: 0)
encoder.setBuffer(newBias, offset: 0, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: testParam.outputTexture)
encoder.endEncoding()
}
}
//
// CNNConvKernel.swift
// paddle-mobile
//
import Foundation
import Metal
import Accelerate
import MetalPerformanceShaders
/// `MPSCNNConvolutionDataSource` that hands out a descriptor, weight pointer
/// and bias pointer supplied at construction time.
///
/// The pointers are stored as given; this class does not copy or free the
/// memory they refer to.
@available(iOS 10.0, *)
class WeightsDataSource: NSObject, MPSCNNConvolutionDataSource {
    let desc: MPSCNNConvolutionDescriptor
    let weight: UnsafeMutableRawPointer
    let bias: UnsafeMutablePointer<Float>

    /// - Parameters:
    ///   - inDesc: Convolution descriptor returned from `descriptor()`.
    ///   - inWeight: Pointer returned from `weights()`.
    ///   - inBias: Pointer returned from `biasTerms()`.
    init(inDesc: MPSCNNConvolutionDescriptor, inWeight: UnsafeMutableRawPointer, inBias: UnsafeMutablePointer<Float>) {
        desc = inDesc
        weight = inWeight
        bias = inBias
    }

    /// Weights are always reported as 32-bit floats.
    func dataType() -> MPSDataType {
        return MPSDataType.float32
    }

    func descriptor() -> MPSCNNConvolutionDescriptor {
        return desc
    }

    func weights() -> UnsafeMutableRawPointer {
        return weight
    }

    func biasTerms() -> UnsafeMutablePointer<Float>? {
        return bias
    }

    /// Nothing to load: the pointers are already available, so this always succeeds.
    func load() -> Bool {
        return true
    }

    /// Nothing to release here.
    func purge() {
    }

    func label() -> String? {
        return "Conv"
    }
}
/// Parameters for the MPS-backed convolution op, read from an op description
/// and the scope that holds its variables.
@available(iOS 10.0, *)
class CNNConvParam<P: PrecisionType>: OpParam{
    typealias ParamPrecisionType = P

    var input: Texture<P>
    let variance: Tensor<ParamPrecisionType>
    /// The `Y` input; the kernel passes its data pointer as the bias terms.
    let y: Tensor<ParamPrecisionType>
    let filter: Tensor<ParamPrecisionType>
    var output: Texture<P>
    let stride: [Int32]
    let paddings: [Int32]
    let dilations: [Int32]
    let groups: Int

    /// Looks up every input, output and attribute the op needs.
    ///
    /// - Throws: The first error raised by any lookup, unchanged.
    required init(opDesc: OpDesc, inScope: Scope) throws {
        filter = try CNNConvParam.inputFilter(paraInputs: opDesc.paraInputs, from: inScope)
        input = try CNNConvParam.input(inputs: opDesc.inputs, from: inScope)
        output = try CNNConvParam.outputOut(outputs: opDesc.outputs, from: inScope)
        stride = try CNNConvParam.getAttr(key: "strides", attrs: opDesc.attrs)
        paddings = try CNNConvParam.getAttr(key: "paddings", attrs: opDesc.attrs)
        // Not needed for now.
        dilations = try CNNConvParam.getAttr(key: "dilations", attrs: opDesc.attrs)
        // Not needed for now.
        groups = try CNNConvParam.getAttr(key: "groups", attrs: opDesc.attrs)
        variance = try CNNConvParam.inputVariance(inputs: opDesc.paraInputs, from: inScope)
        // bias
        y = try CNNConvParam.inputY(inputs: opDesc.paraInputs, from: inScope)
    }
}
/// Convolution kernel backed by `MPSCNNConvolution` instead of a custom Metal function.
///
/// `conv` is only created on iOS 11 or later; on earlier systems it stays `nil`
/// and encoding is skipped by the optional chaining in `compute` / `testCompute`.
@available(iOS 10.0, *)
class CNNConvKernel<P: PrecisionType>: Kernel, Computable {
    typealias ParamType = CNNConvParam<P>

    // NOTE(review): nothing in this class assigns `mpsImageCreator`; unless it is
    // set from outside, the force-unwraps in `compute` / `testCompute` will trap.
    var mpsImageCreator: MpsImageCreator<P>?
    var activation: MPSCNNNeuron?
    var conv: MPSCNNConvolution?
    var weightDataSource: WeightsDataSource?
    var param: CNNConvParam<P>?
    var device: MTLDevice?

    /// Test initialiser: builds the convolution from the sizes and raw pointers in `testParam`.
    required init(device:MTLDevice, testParam:CNNMPSConvTestParam) {
        self.device = device
        let descriptor = MPSCNNConvolutionDescriptor(
            kernelWidth: testParam.filterSize.width,
            kernelHeight: testParam.filterSize.height,
            inputFeatureChannels: testParam.filterSize.channel,
            outputFeatureChannels: testParam.filterSize.channel,
            neuronFilter: activation)
        descriptor.strideInPixelsX = Int(testParam.metalParam.offsetX)
        descriptor.strideInPixelsY = Int(testParam.metalParam.offsetY)
        let dataSource = WeightsDataSource(inDesc: descriptor, inWeight:testParam.filterPointer, inBias:testParam.biasePointer)
        weightDataSource = dataSource
        if #available(iOS 11.0, *) {
            conv = MPSCNNConvolution(device: device, weights: dataSource)
        }
        super.init(device: device, inFunctionName: "")
    }

    /// Builds the convolution from the op parameters.
    ///
    /// The channel count is `param.y.dim[3]` when `param.y.dim.cout()` is 4 and 0
    /// otherwise; it is used for both the input and the output feature channels.
    required init(device:MTLDevice, param:CNNConvParam<P>) {
        self.device = device
        let inChannels: Int = param.y.dim.cout() == 4 ? param.y.dim[3] : 0
        let outChannels: Int = inChannels
        let descriptor = MPSCNNConvolutionDescriptor(
            kernelWidth: param.filter.width,
            kernelHeight: param.filter.height,
            inputFeatureChannels: inChannels,
            outputFeatureChannels: outChannels,
            neuronFilter: activation)
        descriptor.strideInPixelsX = Int(param.stride[0])
        descriptor.strideInPixelsY = Int(param.stride[1])
        let dataSource = WeightsDataSource(inDesc: descriptor, inWeight:param.filter.data.pointer as! UnsafeMutablePointer<Float>, inBias: param.y.data.pointer as! UnsafeMutablePointer<Float>)
        weightDataSource = dataSource
        if #available(iOS 11.0, *) {
            conv = MPSCNNConvolution(device: device, weights: dataSource)
        }
        super.init(device: device, inFunctionName: "")
    }

    /// Encodes the convolution into `commandBuffer` and stores the output image's
    /// texture in `param.input`.
    ///
    /// Both images come from `mpsImageCreator`; the method traps if it is `nil`.
    func compute(commandBuffer: MTLCommandBuffer, param: CNNConvParam<P>) throws {
        let sourceImage: MPSImage = (mpsImageCreator?.createMPSImage(device: device!))!
        let destinationImage = (mpsImageCreator?.createMPSImage(device: device!))!
        // Runs both the conv and the add step: the bias is handed to the Metal API
        // through the weights data source.
        conv?.encode(commandBuffer: commandBuffer, sourceImage: sourceImage, destinationImage: destinationImage)
        // NOTE(review): the result is written to `param.input`, not `param.output` — confirm this is intended.
        param.input = destinationImage.texture as! Texture<P>
    }

    /// Encodes the convolution into `commandBuffer` and stores the output image's
    /// texture in `testParam.outputTexture`.
    ///
    /// Both images come from `mpsImageCreator`; the method traps if it is `nil`.
    func testCompute(commandBuffer: MTLCommandBuffer, testParam: CNNMPSConvTestParam) throws {
        let sourceImage: MPSImage = (mpsImageCreator?.createMPSImage(device: device!))!
        let destinationImage = (mpsImageCreator?.createMPSImage(device: device!))!
        // Runs both the conv and the add step: the bias is handed to the Metal API
        // through the weights data source.
        conv?.encode(commandBuffer: commandBuffer, sourceImage: sourceImage, destinationImage: destinationImage)
        testParam.outputTexture = destinationImage.texture
    }
}
//
// BatchNormRelu.metal
// paddle-mobile
//
#include <metal_stdlib>
using namespace metal;
// Offset and stride parameters for a convolution kernel.
// NOTE(review): no kernel visible in this file reads this struct. It presumably
// mirrors a host-side struct, so field order and types should not be changed
// without checking the Swift side — confirm.
struct MetalConvParam {
// Signed offsets along x, y and z.
short offsetX;
short offsetY;
short offsetZ;
// Unsigned strides along x and y.
ushort strideX;
ushort strideY;
};
// Fused batch-norm + ReLU over a 2D texture array.
//
// For every texel of the output:
//     out = max(in * new_scale[slice] + new_biase[slice], 0)
//
// texture(0): input, sampled with pixel coordinates and nearest filtering
// texture(1): output, written at the thread's position
// buffer(0):  one float4 scale per array slice
// buffer(1):  one float4 bias per array slice
// gid:        (x, y, slice) of the texel handled by this thread
kernel void batch_norm_relu_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]],
                                texture2d_array<float, access::write> outTexture [[texture(1)]],
                                const device float4 *new_scale [[buffer(0)]],
                                const device float4 *new_biase [[buffer(1)]],
                                uint3 gid [[thread_position_in_grid]]) {
    // Threads outside the output extent have nothing to do.
    if (gid.x >= outTexture.get_width() ||
        gid.y >= outTexture.get_height() ||
        gid.z >= outTexture.get_array_size()) {
        return;
    }

    constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);

    // texture2d_array::sample takes (sampler, float2 coord, uint array).
    // Passing x, y and z as three scalars put y in the array slot and z in the
    // offset slot, so the wrong texel and slice were addressed.
    const float4 input = inTexture.sample(sample, float2(gid.x, gid.y), gid.z);
    const float4 output = fmax(input * new_scale[gid.z] + new_biase[gid.z], 0.0);
    outTexture.write(output, gid.xy, gid.z);
}
//
// PaddleMobile.swift
// paddle-mobile-demo
//
// Created by liuRuiLong on 2018/9/5.
// Copyright © 2018年 orange. All rights reserved.
//
import Metal
import MetalKit
import Foundation
/// Compute backends a `Runner` can be asked to use.
public enum Platform {
    case CPU
    case GPU
}
/// Custom kernel bound to the "scale" function, with an output of the given shape.
class ScaleKernel: CusomKernel {
    /// - Parameters:
    ///   - device: Metal device the kernel is created for.
    ///   - shape: Output dimensions passed to the base class as `outputDim`.
    init(device: MTLDevice, shape: Shape) {
        super.init(
            device: device,
            inFunctionName: "scale",
            outputDim: shape,
            usePaddleMobileLib: false
        )
    }
}
/// Describes a network that a `Runner` can load and run.
public protocol Net {
/// Passed to the executor as the `except` argument of `predict`.
/// NOTE(review): the meaning of this value is not visible here — confirm against `Executor`.
var except: Int { get }
/// Input dimensions; `Runner` passes them to the executor in `[n, h, w, c]` order.
var dim: (n: Int, h: Int, w: Int, c: Int) { get }
/// Kernel passed to the executor as the pre-processing kernel.
var preprocessKernel: CusomKernel { get }
// var paramPointer: UnsafeMutableRawPointer { get }
// var paramSize: Int { get }
// var modelPointer: UnsafeMutableRawPointer { get }
// var modelSize: Int { get }
/// Path handed to the loader as `modelPath`.
var modelPath: String { get }
/// Path handed to the loader as `paraPath`.
var paramPath: String { get }
/// NOTE(review): not read by `Runner` in this file; presumably the directory holding the model files — confirm.
var modelDir: String { get }
/// Turns a result array into a string.
func resultStr(res: [Float]) -> String
/// Extracts the result array from the holder produced by a prediction.
func fetchResult(paddleMobileRes: ResultHolder) -> [Float32]
}
extension Net {
    /// Default implementation: returns the holder's `resultArr` unchanged.
    func fetchResult(paddleMobileRes: ResultHolder) -> [Float32] {
        let values: [Float32] = paddleMobileRes.resultArr
        return values
    }
}
/// Loads a `Net` and runs predictions with it.
///
/// Only the GPU platform is implemented: `load()` returns `false` and
/// `predict(cgImage:completion:)` reports failure for `.CPU`.
public class Runner {
    var program: Program?
    var executor: Executor<Float32>?
    var queue: MTLCommandQueue?
    var textureLoader: MTKTextureLoader?
    let net: Net
    let device: MTLDevice?
    let platform: Platform

    /**
     * inNet: the network to run
     * commandQueue: must be supplied when running on the GPU
     * inPlatform: the platform to use, GPU or CPU
     */
    public init(inNet: Net, commandQueue: MTLCommandQueue?, inPlatform: Platform) {
        net = inNet
        queue = commandQueue
        device = commandQueue?.device
        platform = inPlatform
        // A texture loader can only be built when a device is available.
        if let inDevice = device {
            textureLoader = MTKTextureLoader.init(device: inDevice)
        }
    }

    /**
     * Loads the model. Prediction is possible once this returns true.
     */
    public func load() -> Bool {
        guard platform == .GPU else {
            print(" need implementation ")
            return false
        }
        guard let inDevice = device, let inQueue = queue else {
            print(" paddle mobile gpu load error, need MTLCommandQueue")
            return false
        }
        let loader = Loader<Float32>.init()
        do {
            program = try loader.load(device: inDevice, modelPath: net.modelPath, paraPath: net.paramPath)
            // Bind instead of force-unwrapping so a missing program fails the load
            // rather than trapping.
            guard let loadedProgram = program else {
                return false
            }
            executor = try Executor<Float32>.init(inDevice: inDevice, inQueue: inQueue, inProgram: loadedProgram)
        } catch let error {
            print(error)
            return false
        }
        return true
    }

    /**
     * Platform-independent predict.
     * cgImage: the image to run the prediction on
     * completion: callback with three arguments: success flag, elapsed
     *             prediction time, result array. It is always called; on
     *             failure it receives (false, 0.0, []).
     */
    public func predict(cgImage: CGImage, completion: @escaping ( _ success: Bool, _ time:TimeInterval, _ resultArray: [Float32]) -> Void) {
        switch platform {
        case .GPU:
            getTexture(image: cgImage) { [weak self] (texture) in
                // The runner may have been released while the texture was being
                // scaled; report failure instead of trapping.
                guard let runner = self else {
                    completion(false, 0.0, [])
                    return
                }
                runner.predict(texture: texture, completion: completion)
            }
        case .CPU:
            // Not implemented; still report back so the caller is not left waiting.
            print(" need implementation ")
            completion(false, 0.0, [])
        }
    }

    /**
     * GPU predict.
     * texture: the texture to run the prediction on; must already be pre-processed
     * completion: callback with three arguments: success flag, elapsed
     *             prediction time, result array. It is always called; on
     *             failure it receives (false, 0.0, []).
     */
    public func predict(texture: MTLTexture, completion: @escaping ( _ success: Bool, _ time:TimeInterval, _ resultArray: [Float32]) -> Void) {
        // Without an executor (load() not called, failed, or clear() called) the
        // optional chain used to skip the call and never invoke `completion`.
        guard let activeExecutor = executor else {
            print(" paddle mobile predict error, need load ")
            completion(false, 0.0, [])
            return
        }
        do {
            try activeExecutor.predict(input: texture, dim: [net.dim.n, net.dim.h, net.dim.w, net.dim.c], completionHandle: { [weak self] (res) in
                guard let runner = self else {
                    completion(false, 0.0, [])
                    return
                }
                let resultArray = runner.net.fetchResult(paddleMobileRes: res)
                completion(true, res.elapsedTime, resultArray)
            }, preProcessKernle: net.preprocessKernel, except: net.except)
        } catch let error {
            print(error)
            completion(false, 0.0, [])
            return
        }
    }

    /*
     * Frees memory. After this call the runner cannot be used until load() is called again.
     */
    public func clear() {
        executor?.clear()
        executor = nil
        program = nil
    }

    /*
     * Creates a texture from the image and pre-processes (scales) it; used for GPU prediction.
     * Traps if the texture cannot be created, because the callback has no way to report failure.
     */
    public func getTexture(image: CGImage, getTexture: @escaping (MTLTexture) -> Void) {
        let texture = try? textureLoader?.newTexture(cgImage: image, options: [:]) ?! " texture loader error"
        guard let loadedTexture = texture else {
            fatalError( " texture loader error" )
        }
        scaleTexture(input: loadedTexture, size: (net.dim.w, net.dim.h), complete: getTexture)
    }

    /// Scales `input` to `size` with a `ScaleKernel` and calls `complete` with the
    /// kernel's output texture once the command buffer has completed.
    ///
    /// Traps if the queue or device is missing, if no command buffer can be
    /// created, or if encoding the kernel throws.
    func scaleTexture(input: MTLTexture, size:(width: Int, height: Int), complete: @escaping (MTLTexture) -> Void) {
        guard let inQueue = queue, let inDevice = device else {
            fatalError( " queue or devcie nil " )
        }
        guard let buffer = inQueue.makeCommandBuffer() else {
            fatalError( " make buffer error" )
        }
        let scaleKernel = ScaleKernel.init(device: inDevice, shape: CusomKernel.Shape.init(inWidth: size.width, inHeight: size.height, inChannel: 3))
        do {
            try scaleKernel.compute(inputTexuture: input, commandBuffer: buffer)
        } catch let error {
            print(error)
            fatalError()
        }
        // The handler must be registered before commit.
        buffer.addCompletedHandler { (buffer) in
            complete(scaleKernel.outputTexture)
        }
        buffer.commit()
    }
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册