提交 6dfb2ba3 编写于 作者: L liuruilong

add cpu

上级 8d11d7c9
......@@ -84,3 +84,5 @@ SwiftProtobuf.framework
paddle-mobile.xcworkspace
metal/models/
metal/images/
*.a
metal/paddle-mobile/paddle-mobile/CPU/libpaddle-mobile.a
......@@ -17,7 +17,6 @@
FC27991321343A3A000B6BAD /* CPUCompute.mm in Sources */ = {isa = PBXBuildFile; fileRef = FC27991221343A3A000B6BAD /* CPUCompute.mm */; };
FC3C800F2133F46600D1295E /* MobileNetSSD.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3C800E2133F46600D1295E /* MobileNetSSD.swift */; };
FC3C80112133F4AB00D1295E /* MobileNet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3C80102133F4AB00D1295E /* MobileNet.swift */; };
FC4FD95121402B610073E130 /* PaddleMobile.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC4FD95021402B610073E130 /* PaddleMobile.swift */; };
FC8CFEE2213524EA0094D569 /* Genet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC8CFEE1213524EA0094D569 /* Genet.swift */; };
FC8CFEE62135452C0094D569 /* genet_params in Resources */ = {isa = PBXBuildFile; fileRef = FC8CFEE42135452B0094D569 /* genet_params */; };
FC8CFEE72135452C0094D569 /* genet_model in Resources */ = {isa = PBXBuildFile; fileRef = FC8CFEE52135452B0094D569 /* genet_model */; };
......@@ -66,7 +65,7 @@
FC27991421343A46000B6BAD /* CPUCompute.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = CPUCompute.h; sourceTree = "<group>"; };
FC3C800E2133F46600D1295E /* MobileNetSSD.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MobileNetSSD.swift; sourceTree = "<group>"; };
FC3C80102133F4AB00D1295E /* MobileNet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MobileNet.swift; sourceTree = "<group>"; };
FC4FD95021402B610073E130 /* PaddleMobile.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; name = PaddleMobile.swift; path = ../../../../../../Desktop/PaddleMobile.swift; sourceTree = "<group>"; };
FC4FD97B2140EE250073E130 /* libc++.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = "libc++.tbd"; path = "usr/lib/libc++.tbd"; sourceTree = SDKROOT; };
FC8CFEE1213524EA0094D569 /* Genet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Genet.swift; sourceTree = "<group>"; };
FC8CFEE42135452B0094D569 /* genet_params */ = {isa = PBXFileReference; lastKnownFileType = file; path = genet_params; sourceTree = "<group>"; };
FC8CFEE52135452B0094D569 /* genet_model */ = {isa = PBXFileReference; lastKnownFileType = file; path = genet_model; sourceTree = "<group>"; };
......@@ -108,6 +107,7 @@
7B7DED984E9EE7BFB45E24E8 /* Frameworks */ = {
isa = PBXGroup;
children = (
FC4FD97B2140EE250073E130 /* libc++.tbd */,
18896810981724F8A0FED62A /* Pods_paddle_mobile_demo.framework */,
);
name = Frameworks;
......@@ -176,7 +176,6 @@
FC8CFED2213519540094D569 /* Net */ = {
isa = PBXGroup;
children = (
FC4FD95021402B610073E130 /* PaddleMobile.swift */,
FC013927210204A3008100E3 /* PreProcessKernel.metal */,
FCBCCC542122EF5400D94F7E /* MetalHelper.swift */,
FC3C800E2133F46600D1295E /* MobileNetSSD.swift */,
......@@ -346,7 +345,6 @@
FCBCCC552122EF5500D94F7E /* MetalHelper.swift in Sources */,
FC27991321343A3A000B6BAD /* CPUCompute.mm in Sources */,
FC3C80112133F4AB00D1295E /* MobileNet.swift in Sources */,
FC4FD95121402B610073E130 /* PaddleMobile.swift in Sources */,
FC3C800F2133F46600D1295E /* MobileNetSSD.swift in Sources */,
FC039B8220E11C550081E9F8 /* AppDelegate.swift in Sources */,
);
......@@ -497,6 +495,7 @@
CODE_SIGN_IDENTITY = "iPhone Developer";
CODE_SIGN_STYLE = Automatic;
DEVELOPMENT_TEAM = A798K58VVL;
ENABLE_BITCODE = NO;
INFOPLIST_FILE = "paddle-mobile-demo/Info.plist";
IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = (
......@@ -523,6 +522,7 @@
CODE_SIGN_IDENTITY = "iPhone Developer";
CODE_SIGN_STYLE = Automatic;
DEVELOPMENT_TEAM = A798K58VVL;
ENABLE_BITCODE = NO;
INFOPLIST_FILE = "paddle-mobile-demo/Info.plist";
IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = (
......
......@@ -19,7 +19,6 @@ class AppDelegate: UIResponder, UIApplicationDelegate {
var window: UIWindow?
func application(_ application: UIApplication, didFinishLaunchingWithOptions launchOptions: [UIApplicationLaunchOptionsKey: Any]?) -> Bool {
// Override point for customization after application launch.
return true
......
......@@ -17,6 +17,10 @@ import paddle_mobile
class Genet: Net {
var means: [Float] = [128.0, 128.0, 128.0]
var scale: Float = 0.017
let except: Int = 0
class GenetPreProccess: CusomKernel {
......
......@@ -19,6 +19,10 @@ import paddle_mobile
class MobileNet: Net{
var means: [Float] = [123.68, 116.78, 103.94]
var scale: Float = 0.017
let except: Int = 0
class MobilenetPreProccess: CusomKernel {
......
......@@ -16,6 +16,11 @@ import Foundation
import paddle_mobile
class MobileNet_ssd_hand: Net{
var means: [Float] = [123.68, 116.78, 103.94]
var scale: Float = 0.017
let except: Int = 2
class MobilenetssdPreProccess: CusomKernel {
init(device: MTLDevice) {
......
......@@ -17,10 +17,11 @@ import MetalKit
import paddle_mobile
import MetalPerformanceShaders
let platform: Platform = .CPU
let threadSupport = [1]
//.mobilenet : MobileNet.init(),
let modelHelperMap: [SupportModel : Net] = [.mobilenet_ssd : MobileNet_ssd_hand.init(), .genet : Genet.init()]
let modelHelperMap: [SupportModel : Runner] = [.mobilenet_ssd : Runner.init(inNet: MobileNet_ssd_hand.init(), commandQueue: MetalHelper.shared.queue, inPlatform: platform),
.genet : Runner.init(inNet: Genet.init(), commandQueue: MetalHelper.shared.queue, inPlatform: platform)]
//, .genet : Genet.init()
//let modelHelperMap: [SupportModel : Net] = [.mobilenet : MobileNet.init(), .mobilenet_ssd : MobileNet_ssd_hand.init()]
......@@ -30,7 +31,7 @@ enum SupportModel: String{
case genet = "genet"
static func supportedModels() -> [SupportModel] {
//.mobilenet,
return [.mobilenet_ssd ,.genet]
return [.mobilenet_ssd, .genet]
}
}
......@@ -40,12 +41,13 @@ class ViewController: UIViewController {
@IBOutlet weak var elapsedTimeLabel: UILabel!
@IBOutlet weak var modelPickerView: UIPickerView!
@IBOutlet weak var threadPickerView: UIPickerView!
var runnner: Runner!
var selectImage: UIImage?
var inputPointer: UnsafeMutablePointer<Float32>?
var modelType: SupportModel = SupportModel.supportedModels()[0]
var toPredictTexture: MTLTexture?
var net: Net {
var runner: Runner {
get {
return modelHelperMap[modelType] ?! " has no this type "
}
......@@ -56,8 +58,7 @@ class ViewController: UIViewController {
var threadNum = 1
@IBAction func loadAct(_ sender: Any) {
if runnner.load() {
if runner.load() {
print(" load success ! ")
} else {
print(" load error ! ")
......@@ -72,32 +73,66 @@ class ViewController: UIViewController {
}
@IBAction func clearAct(_ sender: Any) {
runnner.clear()
runner.clear()
}
@IBAction func predictAct(_ sender: Any) {
let max = 50
switch platform {
case .GPU:
guard let inTexture = toPredictTexture else {
resultTextView.text = "请选择图片 ! "
return
}
let max = 50
let startDate = Date.init()
for i in 0..<max {
runnner.predict(texture: inTexture) { [weak self] (success, time, result) in
runner.predict(texture: inTexture) { [weak self] (success, res) in
guard let sSelf = self else {
fatalError()
}
if success {
if i == max - 1 {
let time = Date.init().timeIntervalSince(startDate)
DispatchQueue.main.async {
sSelf.resultTextView.text = sSelf.runner.net.resultStr(res: res)
sSelf.elapsedTimeLabel.text = "平均耗时: \(time/Double(max) * 1000.0) ms"
}
}
}
}
}
case .CPU:
guard let inInputPointer = inputPointer else {
fatalError( " need input pointer " )
}
for _ in 0..<10 {
runner.predict(inputPointer: inInputPointer) { (success, res) in
}
}
let startDate = Date.init()
for i in 0..<max {
runner.predict(inputPointer: inInputPointer) { [weak self](success, res) in
guard let sSelf = self else {
fatalError()
}
if success {
if i == max - 1 {
let time = Date.init().timeIntervalSince(startDate)
DispatchQueue.main.async {
sSelf.resultTextView.text = sSelf.net.resultStr(res: result)
sSelf.resultTextView.text = sSelf.runner.net.resultStr(res: res)
sSelf.elapsedTimeLabel.text = "平均耗时: \(time/Double(max) * 1000.0) ms"
}
}
}
}
}
}
}
override func viewDidLoad() {
super.viewDidLoad()
......@@ -109,10 +144,15 @@ class ViewController: UIViewController {
selectImage = UIImage.init(named: "hand.jpg")
selectImageView.image = selectImage
runnner = Runner.init(inNet: net, commandQueue: MetalHelper.shared.queue, inPlatform: .GPU)
runnner.getTexture(image: selectImage!.cgImage!) {[weak self] (texture) in
if platform == .CPU {
inputPointer = runner.preproccess(image: selectImage!.cgImage!)
} else if platform == .GPU {
runner.getTexture(image: selectImage!.cgImage!) {[weak self] (texture) in
self?.toPredictTexture = texture
}
} else {
fatalError( " unsupport " )
}
}
}
......@@ -166,7 +206,7 @@ extension ViewController: UIImagePickerControllerDelegate, UINavigationControll
}
sSelf.selectImage = image
sSelf.selectImageView.image = image
sSelf.runnner.getTexture(image: image.cgImage!, getTexture: { (texture) in
sSelf.runner.getTexture(image: image.cgImage!, getTexture: { (texture) in
sSelf.toPredictTexture = texture
})
}
......
......@@ -45,6 +45,9 @@
FC4CB74920F0B954007C0C6D /* ConvKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC4CB74820F0B954007C0C6D /* ConvKernel.metal */; };
FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */; };
FC4FD9752140E1DE0073E130 /* PaddleMobile.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC4FD9742140E1DE0073E130 /* PaddleMobile.swift */; };
FC4FD9792140E4980073E130 /* PaddleMobile.h in Headers */ = {isa = PBXBuildFile; fileRef = FC4FD9772140E4980073E130 /* PaddleMobile.h */; settings = {ATTRIBUTES = (Public, ); }; };
FC4FD97A2140E4980073E130 /* libpaddle-mobile.a in Frameworks */ = {isa = PBXBuildFile; fileRef = FC4FD9782140E4980073E130 /* libpaddle-mobile.a */; };
FC4FD97E2140F2C30073E130 /* libstdc++.tbd in Frameworks */ = {isa = PBXBuildFile; fileRef = FC4FD97D2140F2C30073E130 /* libstdc++.tbd */; };
FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */; };
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC60DB8820E9AAA500FF203F /* MetalExtension.swift */; };
FC82735920E3C04200BE430A /* OpCreator.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC82735820E3C04200BE430A /* OpCreator.swift */; };
......@@ -138,6 +141,9 @@
FC4CB74820F0B954007C0C6D /* ConvKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvKernel.metal; sourceTree = "<group>"; };
FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ProgramOptimize.swift; sourceTree = "<group>"; };
FC4FD9742140E1DE0073E130 /* PaddleMobile.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PaddleMobile.swift; sourceTree = "<group>"; };
FC4FD9772140E4980073E130 /* PaddleMobile.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = PaddleMobile.h; sourceTree = "<group>"; };
FC4FD9782140E4980073E130 /* libpaddle-mobile.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; path = "libpaddle-mobile.a"; sourceTree = "<group>"; };
FC4FD97D2140F2C30073E130 /* libstdc++.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = "libstdc++.tbd"; path = "usr/lib/libstdc++.tbd"; sourceTree = SDKROOT; };
FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture2DTo2DArrayKernel.swift; sourceTree = "<group>"; };
FC60DB8820E9AAA500FF203F /* MetalExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalExtension.swift; sourceTree = "<group>"; };
FC82735820E3C04200BE430A /* OpCreator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpCreator.swift; sourceTree = "<group>"; };
......@@ -192,7 +198,9 @@
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
FC4FD97E2140F2C30073E130 /* libstdc++.tbd in Frameworks */,
D3831F70E7E0B565B9AC22DA /* Pods_paddle_mobile.framework in Frameworks */,
FC4FD97A2140E4980073E130 /* libpaddle-mobile.a in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
......@@ -202,6 +210,7 @@
336CBE234BF5DE48658DE65F /* Frameworks */ = {
isa = PBXGroup;
children = (
FC4FD97D2140F2C30073E130 /* libstdc++.tbd */,
DD2E06330A1E7129C918DB46 /* Pods_paddle_mobile.framework */,
);
name = Frameworks;
......@@ -237,6 +246,7 @@
FC039B6C20E11C3C0081E9F8 /* paddle-mobile */ = {
isa = PBXGroup;
children = (
FC4FD9762140E4920073E130 /* CPU */,
FC4FD9742140E1DE0073E130 /* PaddleMobile.swift */,
FC039BAE20E11CC20081E9F8 /* Program */,
FC039BA320E11CBC0081E9F8 /* Operators */,
......@@ -347,6 +357,15 @@
path = Kernels;
sourceTree = "<group>";
};
FC4FD9762140E4920073E130 /* CPU */ = {
isa = PBXGroup;
children = (
FC4FD9782140E4980073E130 /* libpaddle-mobile.a */,
FC4FD9772140E4980073E130 /* PaddleMobile.h */,
);
path = CPU;
sourceTree = "<group>";
};
FCD592FA20E248EC00252966 /* Base */ = {
isa = PBXGroup;
children = (
......@@ -398,6 +417,7 @@
isa = PBXHeadersBuildPhase;
buildActionMask = 2147483647;
files = (
FC4FD9792140E4980073E130 /* PaddleMobile.h in Headers */,
FC039B6F20E11C3C0081E9F8 /* paddle_mobile.h in Headers */,
);
runOnlyForDeploymentPostprocessing = 0;
......@@ -711,6 +731,7 @@
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_BITCODE = NO;
INFOPLIST_FILE = "paddle-mobile/Info.plist";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
IPHONEOS_DEPLOYMENT_TARGET = 9.0;
......@@ -719,6 +740,10 @@
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
LIBRARY_SEARCH_PATHS = (
"$(inherited)",
"$(PROJECT_DIR)/paddle-mobile/CPU",
);
MACH_O_TYPE = mh_dylib;
MTL_LANGUAGE_REVISION = UseDeploymentTarget;
PRODUCT_BUNDLE_IDENTIFIER = "orange.paddle-mobile";
......@@ -740,6 +765,7 @@
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_BITCODE = NO;
INFOPLIST_FILE = "paddle-mobile/Info.plist";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
IPHONEOS_DEPLOYMENT_TARGET = 9.0;
......@@ -748,6 +774,10 @@
"@executable_path/Frameworks",
"@loader_path/Frameworks",
);
LIBRARY_SEARCH_PATHS = (
"$(inherited)",
"$(PROJECT_DIR)/paddle-mobile/CPU",
);
MACH_O_TYPE = mh_dylib;
MTL_LANGUAGE_REVISION = UseDeploymentTarget;
PRODUCT_BUNDLE_IDENTIFIER = "orange.paddle-mobile";
......
......@@ -33,7 +33,7 @@
</AdditionalOptions>
</TestAction>
<LaunchAction
buildConfiguration = "Debug"
buildConfiguration = "Release"
selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
launchStyle = "0"
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#import <CoreImage/CoreImage.h>
#import <Foundation/Foundation.h>
@interface PaddleMobile : NSObject
/*
创建对象
*/
- (instancetype)init;
/*
load 模型, 开辟内存
*/
- (BOOL)load:(NSString *)modelPath andWeightsPath:(NSString *)weighsPath;
/*
加载散开形式的模型, 需传入模型的目录
*/
- (BOOL)load:(NSString *)modelAndWeightPath;
/*
* 从内存中加载模型
* */
- (BOOL)LoadCombinedMemory:(size_t)modelLen
andModelBuf:(const uint8_t *)modelBuf
andModelParamsLen:(size_t)combinedParamsLen
andCombinedParamsBuf:(const uint8_t *)combinedParamsBuf;
/*
* 进行预测, means 和 scale 为训练模型时的预处理参数, 如训练时没有做这些预处理则直接使用 predict
*/
- (NSArray *)predict:(CGImageRef)image
dim:(NSArray<NSNumber *> *)dim
means:(NSArray<NSNumber *> *)means
scale:(float)scale;
/*
* 预测输入
* */
- (NSArray *)predictInput:(float *)input
dim:(NSArray<NSNumber *> *)dim
means:(NSArray<NSNumber *> *)means
scale:(float)scale;
/*
* 对图像进行预处理
* */
-(void)preprocess:(CGImageRef)image
output:(float *)output
means:(NSArray<NSNumber *> *)means
scale:(float)scale
dim:(NSArray<NSNumber *> *)dim;
/*
清理内存
*/
- (void)clear;
@end
......@@ -22,6 +22,8 @@ class ScaleKernel: CusomKernel {
public protocol Net {
var except: Int { get }
var means: [Float] { get }
var scale: Float { get }
var dim: (n: Int, h: Int, w: Int, c: Int) { get }
var preprocessKernel: CusomKernel { get }
// var paramPointer: UnsafeMutableRawPointer { get }
......@@ -36,7 +38,7 @@ public protocol Net {
}
extension Net {
func fetchResult(paddleMobileRes: ResultHolder) -> [Float32] {
public func fetchResult(paddleMobileRes: ResultHolder) -> [Float32] {
return paddleMobileRes.resultArr
}
}
......@@ -46,9 +48,15 @@ public class Runner {
var executor: Executor<Float32>?
var queue: MTLCommandQueue?
var textureLoader: MTKTextureLoader?
let net: Net
public let net: Net
let device: MTLDevice?
let platform: Platform
var cpuPaddleMobile: PaddleMobile?
let numel: Int
let meansNumber: [NSNumber]
// dims num nchw
let dimsNum: [NSNumber]
/**
* inNet: 需要运行的网络
* commandQueue: GPU 是需要传入
......@@ -62,6 +70,15 @@ public class Runner {
if let inDevice = device {
textureLoader = MTKTextureLoader.init(device: inDevice)
}
if platform == .CPU {
cpuPaddleMobile = PaddleMobile.init()
}
numel = net.dim.n * net.dim.c * net.dim.h * net.dim.w
meansNumber = net.means.map { NSNumber.init(value: $0) }
dimsNum = [NSNumber.init(value: net.dim.n),
NSNumber.init(value: net.dim.c),
NSNumber.init(value: net.dim.h),
NSNumber.init(value: net.dim.w)]
}
/**
......@@ -82,56 +99,83 @@ public class Runner {
return false
}
} else {
print(" need implementation ")
return false
return cpuPaddleMobile?.load(net.modelPath, andWeightsPath: net.paramPath) ?? false
}
return true
}
/**
* CPU GPU 通用版本 predict
* cgImage: 需要预测的图片
* ( _ success: Bool, _ time:TimeInterval, _ resultArray: [Float32]) -> Void : 回调闭包, 三个参数分别为: 是否成功, 预测耗时, 结果数组
*/
public func predict(cgImage: CGImage, completion: @escaping ( _ success: Bool, _ time:TimeInterval, _ resultArray: [Float32]) -> Void) {
if platform == .GPU {
getTexture(image: cgImage) { [weak self] (texture) in
guard let SSelf = self else {
fatalError()
public func predict(inputPointer: UnsafeMutablePointer<Float32>, completion: @escaping ( _ success: Bool, _ resultArray: [Float32]) -> Void) {
guard let res = cpuPaddleMobile?.predictInput(inputPointer, dim: dimsNum, means: meansNumber, scale: net.scale) else {
completion(false, [])
return
}
SSelf.predict(texture: texture, completion: completion)
completion(true, res.map { ($0 as! NSNumber).floatValue })
}
} else if platform == .CPU {
}
}
/**
* GPU 版本 predict
* texture: 需要预测的 texture 需要做过预处理
* ( _ success: Bool, _ time:TimeInterval, _ resultArray: [Float32]) -> Void : 回调闭包, 三个参数分别为: 是否成功, 预测耗时, 结果数组
*/
public func predict(texture: MTLTexture, completion: @escaping ( _ success: Bool, _ time:TimeInterval, _ resultArray: [Float32]) -> Void) {
public func predict(texture: MTLTexture, completion: @escaping ( _ success: Bool, _ resultArray: [Float32]) -> Void) {
do {
try self.executor?.predict(input: texture, dim: [self.net.dim.n, self.net.dim.h, self.net.dim.w, self.net.dim.c], completionHandle: { [weak self] (res) in
guard let SSelf = self else {
fatalError( " self nil " )
}
let resultArray = SSelf.net.fetchResult(paddleMobileRes: res)
completion(true, res.elapsedTime, resultArray)
completion(true, resultArray)
}, preProcessKernle: self.net.preprocessKernel, except: self.net.except)
} catch let error {
print(error)
completion(false, 0.0, [])
completion(false, [])
return
}
}
/**
* CPU GPU 通用版本 predict
* cgImage: 需要预测的图片
* ( _ success: Bool, _ time:TimeInterval, _ resultArray: [Float32]) -> Void : 回调闭包, 三个参数分别为: 是否成功, 预测耗时, 结果数组
*/
public func predict(cgImage: CGImage, completion: @escaping ( _ success: Bool, _ resultArray: [Float32]) -> Void) {
if platform == .GPU {
getTexture(image: cgImage) { [weak self] (texture) in
guard let SSelf = self else {
fatalError( "" )
}
SSelf.predict(texture: texture, completion: completion)
}
} else if platform == .CPU {
let input = preproccess(image: cgImage)
predict(inputPointer: input, completion: completion)
input.deinitialize(count: numel)
input.deallocate()
}
}
/*
* 清理内存, 调用此函数后, 不能再使用, 需重新 load
*/
public func clear() {
if platform == .GPU {
executor?.clear()
executor = nil
program = nil
} else if platform == .CPU {
cpuPaddleMobile?.clear()
}
}
public func preproccess(image: CGImage) -> UnsafeMutablePointer<Float> {
let output = UnsafeMutablePointer<Float>.allocate(capacity: numel)
let means = net.means.map { NSNumber.init(value: $0) }
let dims = [NSNumber.init(value: net.dim.n),
NSNumber.init(value: net.dim.c),
NSNumber.init(value: net.dim.h),
NSNumber.init(value: net.dim.w)]
cpuPaddleMobile?.preprocess(image, output: output, means: means, scale: net.scale, dim: dims)
return output
}
/*
......@@ -169,7 +213,3 @@ public class Runner {
}
......@@ -14,12 +14,14 @@
#pragma once
#import "PaddleMobile.h"
#import <UIKit/UIKit.h>
//! Project version number for paddle_mobile.
FOUNDATION_EXPORT double paddle_mobileVersionNumber;
//FOUNDATION_EXPORT double paddle_mobileVersionNumber;
//! Project version string for paddle_mobile.
FOUNDATION_EXPORT const unsigned char paddle_mobileVersionString[];
//FOUNDATION_EXPORT const unsigned char paddle_mobileVersionString[];
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册