未验证 提交 1c72c0ef 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #989 from codeWorm2015/metal

update program when  mobilentfssd
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
*.lai *.lai
*.la *.la
*.lib *.lib
*.a
# Executables # Executables
*.exe *.exe
......
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
FCEBEC2C20E1391F00C0B14D /* paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; }; FCEBEC2C20E1391F00C0B14D /* paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; };
FCEBEC2D20E1391F00C0B14D /* paddle_mobile.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; }; FCEBEC2D20E1391F00C0B14D /* paddle_mobile.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
FCEEE7D4210627A000444BEC /* banana.jpeg in Resources */ = {isa = PBXBuildFile; fileRef = FCEEE7D3210627A000444BEC /* banana.jpeg */; }; FCEEE7D4210627A000444BEC /* banana.jpeg in Resources */ = {isa = PBXBuildFile; fileRef = FCEEE7D3210627A000444BEC /* banana.jpeg */; };
FCF437E8214B6DDB00943429 /* Multi-Predict-ViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCF437E7214B6DDB00943429 /* Multi-Predict-ViewController.swift */; }; FCF437E8214B6DDB00943429 /* MultiPredictViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCF437E7214B6DDB00943429 /* MultiPredictViewController.swift */; };
/* End PBXBuildFile section */ /* End PBXBuildFile section */
/* Begin PBXCopyFilesBuildPhase section */ /* Begin PBXCopyFilesBuildPhase section */
...@@ -80,7 +80,7 @@ ...@@ -80,7 +80,7 @@
FCDFD41A211D91C7005AB38B /* synset.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = synset.txt; sourceTree = "<group>"; }; FCDFD41A211D91C7005AB38B /* synset.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = synset.txt; sourceTree = "<group>"; };
FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; }; FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; };
FCEEE7D3210627A000444BEC /* banana.jpeg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = banana.jpeg; sourceTree = "<group>"; }; FCEEE7D3210627A000444BEC /* banana.jpeg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = banana.jpeg; sourceTree = "<group>"; };
FCF437E7214B6DDB00943429 /* Multi-Predict-ViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "Multi-Predict-ViewController.swift"; sourceTree = "<group>"; }; FCF437E7214B6DDB00943429 /* MultiPredictViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MultiPredictViewController.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */ /* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */ /* Begin PBXFrameworksBuildPhase section */
...@@ -147,7 +147,7 @@ ...@@ -147,7 +147,7 @@
FC039B8A20E11C560081E9F8 /* LaunchScreen.storyboard */, FC039B8A20E11C560081E9F8 /* LaunchScreen.storyboard */,
FC039B8D20E11C560081E9F8 /* Info.plist */, FC039B8D20E11C560081E9F8 /* Info.plist */,
FC27991121343A39000B6BAD /* paddle-mobile-demo-Bridging-Header.h */, FC27991121343A39000B6BAD /* paddle-mobile-demo-Bridging-Header.h */,
FCF437E7214B6DDB00943429 /* Multi-Predict-ViewController.swift */, FCF437E7214B6DDB00943429 /* MultiPredictViewController.swift */,
); );
path = "paddle-mobile-demo"; path = "paddle-mobile-demo";
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -361,7 +361,7 @@ ...@@ -361,7 +361,7 @@
FC039B8420E11C550081E9F8 /* ViewController.swift in Sources */, FC039B8420E11C550081E9F8 /* ViewController.swift in Sources */,
FC803BCE214D27930094B8E5 /* VideoCapture.swift in Sources */, FC803BCE214D27930094B8E5 /* VideoCapture.swift in Sources */,
FC013928210204A3008100E3 /* PreProcessKernel.metal in Sources */, FC013928210204A3008100E3 /* PreProcessKernel.metal in Sources */,
FCF437E8214B6DDB00943429 /* Multi-Predict-ViewController.swift in Sources */, FCF437E8214B6DDB00943429 /* MultiPredictViewController.swift in Sources */,
FCBCCC552122EF5500D94F7E /* MetalHelper.swift in Sources */, FCBCCC552122EF5500D94F7E /* MetalHelper.swift in Sources */,
FC803BCD214D27930094B8E5 /* FPSCounter.swift in Sources */, FC803BCD214D27930094B8E5 /* FPSCounter.swift in Sources */,
FC039B8220E11C550081E9F8 /* AppDelegate.swift in Sources */, FC039B8220E11C550081E9F8 /* AppDelegate.swift in Sources */,
......
...@@ -42,7 +42,7 @@ ...@@ -42,7 +42,7 @@
</AdditionalOptions> </AdditionalOptions>
</TestAction> </TestAction>
<LaunchAction <LaunchAction
buildConfiguration = "Release" buildConfiguration = "Debug"
selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB" selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB" selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
launchStyle = "0" launchStyle = "0"
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
<!--Multi Predict View Controller--> <!--Multi Predict View Controller-->
<scene sceneID="ec4-AW-9Vs"> <scene sceneID="ec4-AW-9Vs">
<objects> <objects>
<viewController id="Vwd-lt-764" customClass="Multi_Predict_ViewController" customModule="paddle_mobile_demo" customModuleProvider="target" sceneMemberID="viewController"> <viewController id="Vwd-lt-764" customClass="MultiPredictViewController" customModule="paddle_mobile_demo" customModuleProvider="target" sceneMemberID="viewController">
<view key="view" contentMode="scaleToFill" id="55D-rz-Ex6"> <view key="view" contentMode="scaleToFill" id="55D-rz-Ex6">
<rect key="frame" x="0.0" y="0.0" width="375" height="667"/> <rect key="frame" x="0.0" y="0.0" width="375" height="667"/>
<autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/> <autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/>
...@@ -50,7 +50,7 @@ ...@@ -50,7 +50,7 @@
<imageView userInteractionEnabled="NO" contentMode="scaleAspectFit" horizontalHuggingPriority="251" verticalHuggingPriority="251" translatesAutoresizingMaskIntoConstraints="NO" id="ZZh-fw-LwK"> <imageView userInteractionEnabled="NO" contentMode="scaleAspectFit" horizontalHuggingPriority="251" verticalHuggingPriority="251" translatesAutoresizingMaskIntoConstraints="NO" id="ZZh-fw-LwK">
<rect key="frame" x="0.0" y="20" width="225" height="247"/> <rect key="frame" x="0.0" y="20" width="225" height="247"/>
</imageView> </imageView>
<label opaque="NO" userInteractionEnabled="NO" contentMode="left" horizontalHuggingPriority="251" verticalHuggingPriority="251" text="Thread:" textAlignment="natural" lineBreakMode="tailTruncation" baselineAdjustment="alignBaselines" adjustsFontSizeToFit="NO" translatesAutoresizingMaskIntoConstraints="NO" id="2EB-m2-a3L"> <label opaque="NO" userInteractionEnabled="NO" contentMode="left" horizontalHuggingPriority="251" verticalHuggingPriority="251" text="Platform:" textAlignment="natural" lineBreakMode="tailTruncation" baselineAdjustment="alignBaselines" adjustsFontSizeToFit="NO" translatesAutoresizingMaskIntoConstraints="NO" id="2EB-m2-a3L">
<rect key="frame" x="10" y="538" width="68" height="24"/> <rect key="frame" x="10" y="538" width="68" height="24"/>
<constraints> <constraints>
<constraint firstAttribute="width" constant="68" id="Q5J-tq-JSX"/> <constraint firstAttribute="width" constant="68" id="Q5J-tq-JSX"/>
...@@ -236,7 +236,7 @@ ...@@ -236,7 +236,7 @@
</viewController> </viewController>
<placeholder placeholderIdentifier="IBFirstResponder" id="dkx-z0-nzr" sceneMemberID="firstResponder"/> <placeholder placeholderIdentifier="IBFirstResponder" id="dkx-z0-nzr" sceneMemberID="firstResponder"/>
</objects> </objects>
<point key="canvasLocation" x="-1543.2" y="-147.07646176911544"/> <point key="canvasLocation" x="-1127" y="-3"/>
</scene> </scene>
</scenes> </scenes>
<resources> <resources>
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
import UIKit import UIKit
import paddle_mobile import paddle_mobile
class Multi_Predict_ViewController: UIViewController { class MultiPredictViewController: UIViewController {
var runner1: Runner! var runner1: Runner!
var runner2: Runner! var runner2: Runner!
override func viewDidLoad() { override func viewDidLoad() {
......
...@@ -18,8 +18,8 @@ import CoreMedia ...@@ -18,8 +18,8 @@ import CoreMedia
import paddle_mobile import paddle_mobile
import MetalPerformanceShaders import MetalPerformanceShaders
let platform: Platform = .GPU var platform: Platform = .GPU
let threadSupport = [1] let threadSupport: [(Platform, String)] = [(.GPU, "GPU"), (.CPU, "CPU")]
//.mobilenet_ssd : Runner.init(inNet: MobileNet_ssd_hand.init(device: MetalHelper.shared.device), commandQueue: MetalHelper.shared.queue, inPlatform: platform), //.mobilenet_ssd : Runner.init(inNet: MobileNet_ssd_hand.init(device: MetalHelper.shared.device), commandQueue: MetalHelper.shared.queue, inPlatform: platform),
let modelHelperMap: [SupportModel : Runner] = [ let modelHelperMap: [SupportModel : Runner] = [
...@@ -28,6 +28,8 @@ let modelHelperMap: [SupportModel : Runner] = [ ...@@ -28,6 +28,8 @@ let modelHelperMap: [SupportModel : Runner] = [
//, .genet : Genet.init() //, .genet : Genet.init()
//let modelHelperMap: [SupportModel : Net] = [.mobilenet : MobileNet.init(), .mobilenet_ssd : MobileNet_ssd_hand.init()] //let modelHelperMap: [SupportModel : Net] = [.mobilenet : MobileNet.init(), .mobilenet_ssd : MobileNet_ssd_hand.init()]
let netSupport: [SupportModel : Net] = [.genet : Genet.init(device: MetalHelper.shared.device), .mobilenet_ssd_ar : MobileNet_ssd_AR.init(device: MetalHelper.shared.device)]
enum SupportModel: String{ enum SupportModel: String{
// case mobilenet = "mobilenet" // case mobilenet = "mobilenet"
// case mobilenet_ssd = "mobilenetssd" // case mobilenet_ssd = "mobilenetssd"
...@@ -55,17 +57,28 @@ class ViewController: UIViewController { ...@@ -55,17 +57,28 @@ class ViewController: UIViewController {
var modelType: SupportModel = SupportModel.supportedModels()[0] var modelType: SupportModel = SupportModel.supportedModels()[0]
var toPredictTexture: MTLTexture? var toPredictTexture: MTLTexture?
var runner: Runner { var runner: Runner!
get {
return modelHelperMap[modelType] ?! " has no this type "
}
set {
}
}
var threadNum = 1 var threadNum = 1
@IBAction func loadAct(_ sender: Any) { @IBAction func loadAct(_ sender: Any) {
runner = Runner.init(inNet: netSupport[modelType]!, commandQueue: MetalHelper.shared.queue, inPlatform: platform)
if platform == .CPU {
if inputPointer == nil {
inputPointer = runner.preproccess(image: selectImage!.cgImage!)
}
} else if platform == .GPU {
if self.toPredictTexture == nil {
runner.getTexture(image: selectImage!.cgImage!) {[weak self] (texture) in
self?.toPredictTexture = texture
}
}
} else {
fatalError( " unsupport " )
}
if runner.load() { if runner.load() {
print(" load success ! ") print(" load success ! ")
} else { } else {
...@@ -128,6 +141,7 @@ class ViewController: UIViewController { ...@@ -128,6 +141,7 @@ class ViewController: UIViewController {
for _ in 0..<10 { for _ in 0..<10 {
runner.predict(inputPointer: inInputPointer) { (success, res) in runner.predict(inputPointer: inInputPointer) { (success, res) in
res?.releaseOutput()
} }
} }
...@@ -146,6 +160,7 @@ class ViewController: UIViewController { ...@@ -146,6 +160,7 @@ class ViewController: UIViewController {
} }
} }
} }
res?.releaseOutput()
} }
} }
} }
...@@ -168,15 +183,15 @@ class ViewController: UIViewController { ...@@ -168,15 +183,15 @@ class ViewController: UIViewController {
selectImage = UIImage.init(named: "hand.jpg") selectImage = UIImage.init(named: "hand.jpg")
selectImageView.image = selectImage selectImageView.image = selectImage
if platform == .CPU { // if platform == .CPU {
inputPointer = runner.preproccess(image: selectImage!.cgImage!) // inputPointer = runner.preproccess(image: selectImage!.cgImage!)
} else if platform == .GPU { // } else if platform == .GPU {
runner.getTexture(image: selectImage!.cgImage!) {[weak self] (texture) in // runner.getTexture(image: selectImage!.cgImage!) {[weak self] (texture) in
self?.toPredictTexture = texture // self?.toPredictTexture = texture
} // }
} else { // } else {
fatalError( " unsupport " ) // fatalError( " unsupport " )
} // }
// videoCapture = VideoCapture.init(device: MetalHelper.shared.device, orientation: .portrait, position: .back) // videoCapture = VideoCapture.init(device: MetalHelper.shared.device, orientation: .portrait, position: .back)
// videoCapture.fps = 30 // videoCapture.fps = 30
...@@ -219,7 +234,7 @@ extension ViewController: UIPickerViewDataSource, UIPickerViewDelegate{ ...@@ -219,7 +234,7 @@ extension ViewController: UIPickerViewDataSource, UIPickerViewDelegate{
if pickerView == modelPickerView { if pickerView == modelPickerView {
return SupportModel.supportedModels()[row].rawValue return SupportModel.supportedModels()[row].rawValue
} else if pickerView == threadPickerView { } else if pickerView == threadPickerView {
return "\(threadSupport[row])" return threadSupport[row].1
} else { } else {
fatalError() fatalError()
} }
...@@ -229,7 +244,8 @@ extension ViewController: UIPickerViewDataSource, UIPickerViewDelegate{ ...@@ -229,7 +244,8 @@ extension ViewController: UIPickerViewDataSource, UIPickerViewDelegate{
if pickerView == modelPickerView { if pickerView == modelPickerView {
self.modelType = SupportModel.supportedModels()[row] self.modelType = SupportModel.supportedModels()[row]
} else if pickerView == threadPickerView { } else if pickerView == threadPickerView {
self.threadNum = threadSupport[row]
platform = threadSupport[row].0
} else { } else {
fatalError() fatalError()
} }
......
...@@ -33,7 +33,7 @@ ...@@ -33,7 +33,7 @@
</AdditionalOptions> </AdditionalOptions>
</TestAction> </TestAction>
<LaunchAction <LaunchAction
buildConfiguration = "Release" buildConfiguration = "Debug"
selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB" selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB" selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
launchStyle = "0" launchStyle = "0"
......
...@@ -17,7 +17,17 @@ ...@@ -17,7 +17,17 @@
#import <CoreImage/CoreImage.h> #import <CoreImage/CoreImage.h>
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
@interface PaddleMobile : NSObject @interface PaddleMobileCPUResult: NSObject
@property (assign, nonatomic, readonly) float *output;
@property (assign, nonatomic, readonly) int outputSize;
-(void)releaseOutput;
@end
@interface PaddleMobileCPU : NSObject
/* /*
创建对象 创建对象
...@@ -42,25 +52,8 @@ ...@@ -42,25 +52,8 @@
andModelParamsLen:(size_t)combinedParamsLen andModelParamsLen:(size_t)combinedParamsLen
andCombinedParamsBuf:(const uint8_t *)combinedParamsBuf; andCombinedParamsBuf:(const uint8_t *)combinedParamsBuf;
/*
* 进行预测, means 和 scale 为训练模型时的预处理参数, 如训练时没有做这些预处理则直接使用 predict
*/
- (NSArray *)predict:(CGImageRef)image
dim:(NSArray<NSNumber *> *)dim
means:(NSArray<NSNumber *> *)means
scale:(float)scale;
/* /*
* 预测输入 * 对图像进行预处理, 需要外部开辟 output 内存, 外部释放 output 内存
* */
- (NSArray *)predictInput:(float *)input
dim:(NSArray<NSNumber *> *)dim
means:(NSArray<NSNumber *> *)means
scale:(float)scale;
/*
* 对图像进行预处理
* */ * */
-(void)preprocess:(CGImageRef)image -(void)preprocess:(CGImageRef)image
output:(float *)output output:(float *)output
...@@ -68,6 +61,22 @@ ...@@ -68,6 +61,22 @@
scale:(float)scale scale:(float)scale
dim:(NSArray<NSNumber *> *)dim; dim:(NSArray<NSNumber *> *)dim;
/*
* 预测预处理后的数据, 返回结果使用结束需要调用其 realseOutput 函数进行释放
* */
- (PaddleMobileCPUResult *)predictInput:(float *)input
dim:(NSArray<NSNumber *> *)dim;
/*
进行预测, means 和 scale 为训练模型时的预处理参数, 如训练时没有做这些预处理则直接使用 predict
*/
- (NSArray *)predict:(CGImageRef)image dim:(NSArray<NSNumber *> *)dim means:(NSArray<NSNumber *> *)means scale:(float)scale;
/*
进行预测, 默认 means 为 0, scale 为 1.0
*/
- (NSArray *)predict:(CGImageRef)image dim:(NSArray<NSNumber *> *)dim;
/* /*
清理内存 清理内存
*/ */
......
...@@ -82,4 +82,60 @@ public class MobileNet_ssd_AR: Net{ ...@@ -82,4 +82,60 @@ public class MobileNet_ssd_AR: Net{
// print(resultHolder.result![0]) // print(resultHolder.result![0])
return resultHolder return resultHolder
} }
override func updateProgram(program: Program) {
for i in [56, 66, 76, 86, 93, 99] {
let opDesc = program.programDesc.blocks[0].ops[i]
let output = opDesc.outputs["Out"]!.first!
let v = program.scope[output]!
let originTexture = v as! Texture<Float32>
originTexture.tensorDim = Dim.init(inDim: [originTexture.tensorDim[1] / 7, originTexture.tensorDim[0] * 7])
originTexture.dim = Dim.init(inDim: [1, 1, originTexture.dim[3] / 7, originTexture.dim[2] * 7])
originTexture.padToFourDim = Dim.init(inDim: [1, 1, originTexture.padToFourDim[3] / 7, originTexture.padToFourDim[2] * 7])
program.scope[output] = originTexture
if i == 99 {
opDesc.attrs["axis"] = 0
} else {
opDesc.attrs["shape"] = originTexture.tensorDim.dims.map { Int32($0) }
}
}
for i in [58, 59, 88, 89, 95, 96, 68, 69, 78, 79] {
let opDesc = program.programDesc.blocks[0].ops[i]
let output = opDesc.outputs["Out"]!.first!
let v = program.scope[output]!
let originTexture = v as! Texture<Float32>
originTexture.tensorDim = Dim.init(inDim: [originTexture.tensorDim[1], originTexture.tensorDim[2]])
opDesc.attrs["shape"] = originTexture.tensorDim.dims.map { Int32($0) }
}
for i in [60, 101, 90, 97, 70, 80] {
let opDesc = program.programDesc.blocks[0].ops[i]
let output = opDesc.outputs["Out"]!.first!
let v = program.scope[output]!
let originTexture = v as! Texture<Float32>
originTexture.tensorDim = Dim.init(inDim: [originTexture.tensorDim[1], originTexture.tensorDim[2]])
opDesc.attrs["axis"] = (opDesc.attrs["axis"]! as! Int) - 1
}
for i in [102] {
let opDesc = program.programDesc.blocks[0].ops[i]
for output in opDesc.outputs["Out"]! {
let v = program.scope[output]!
let originTexture = v as! Texture<Float32>
originTexture.tensorDim = Dim.init(inDim: [originTexture.tensorDim[1], originTexture.tensorDim[2]])
}
opDesc.attrs["axis"] = (opDesc.attrs["axis"]! as! Int) - 1
print(" split axis \(opDesc.attrs["axis"])")
}
// 99
}
} }
...@@ -55,4 +55,8 @@ public class Net: NSObject { ...@@ -55,4 +55,8 @@ public class Net: NSObject {
@objc public init(device: MTLDevice) { @objc public init(device: MTLDevice) {
super.init() super.init()
} }
func updateProgram(program: Program) {
}
} }
...@@ -125,7 +125,6 @@ struct ConcatParam { ...@@ -125,7 +125,6 @@ struct ConcatParam {
#undef R #undef R
#undef V #undef V
#define V VNORMAL #define V VNORMAL
#define R 4 #define R 4
#define N 2 #define N 2
...@@ -138,3 +137,35 @@ struct ConcatParam { ...@@ -138,3 +137,35 @@ struct ConcatParam {
#undef N #undef N
#undef R #undef R
#undef V #undef V
#define V VY
#define R 2
#define N 2
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
#define V VY
#define R 2
#define N 5
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
...@@ -48,3 +48,17 @@ struct SplitParam { ...@@ -48,3 +48,17 @@ struct SplitParam {
#undef R #undef R
#undef V #undef V
//// ssd-ar: (R=2, N=2, V=y)
#define V VY
#define R 2
#define N 2
#define P float
#include "Split.inc.metal"
#undef P
#define P half
#include "Split.inc.metal"
#undef P
#undef N
#undef R
#undef V
...@@ -39,6 +39,7 @@ class SplitParam<P: PrecisionType>: OpParam { ...@@ -39,6 +39,7 @@ class SplitParam<P: PrecisionType>: OpParam {
throw error throw error
} }
} }
var axis: Int var axis: Int
let input: Texture<P> let input: Texture<P>
var output: Texture<P> var output: Texture<P>
......
...@@ -34,7 +34,7 @@ public class Runner: NSObject { ...@@ -34,7 +34,7 @@ public class Runner: NSObject {
public let net: Net public let net: Net
let device: MTLDevice? let device: MTLDevice?
let platform: Platform let platform: Platform
var cpuPaddleMobile: PaddleMobile? var cpuPaddleMobile: PaddleMobileCPU?
let numel: Int let numel: Int
let meansNumber: [NSNumber] let meansNumber: [NSNumber]
...@@ -54,7 +54,7 @@ public class Runner: NSObject { ...@@ -54,7 +54,7 @@ public class Runner: NSObject {
textureLoader = MTKTextureLoader.init(device: inDevice) textureLoader = MTKTextureLoader.init(device: inDevice)
} }
if platform == .CPU { if platform == .CPU {
cpuPaddleMobile = PaddleMobile.init() cpuPaddleMobile = PaddleMobileCPU.init()
} }
numel = net.dim.n * net.dim.c * net.dim.h * net.dim.w numel = net.dim.n * net.dim.c * net.dim.h * net.dim.w
meansNumber = net.means.map { NSNumber.init(value: $0) } meansNumber = net.means.map { NSNumber.init(value: $0) }
...@@ -76,6 +76,7 @@ public class Runner: NSObject { ...@@ -76,6 +76,7 @@ public class Runner: NSObject {
let loader = Loader<Float32>.init() let loader = Loader<Float32>.init()
do { do {
program = try loader.load(device: inDevice, modelPath: net.modelPath, paraPath: net.paramPath) program = try loader.load(device: inDevice, modelPath: net.modelPath, paraPath: net.paramPath)
net.updateProgram(program: program!)
executor = try Executor<Float32>.init(inDevice: inDevice, inQueue: inQueue, inProgram: program!) executor = try Executor<Float32>.init(inDevice: inDevice, inQueue: inQueue, inProgram: program!)
} catch let error { } catch let error {
print(error) print(error)
...@@ -87,12 +88,13 @@ public class Runner: NSObject { ...@@ -87,12 +88,13 @@ public class Runner: NSObject {
return true return true
} }
@objc public func predict(inputPointer: UnsafeMutablePointer<Float32>, completion: @escaping ( _ success: Bool, _ resultArray: [Float32]) -> Void) { @objc public func predict(inputPointer: UnsafeMutablePointer<Float32>, completion: @escaping ( _ success: Bool, _ result: PaddleMobileCPUResult?) -> Void) {
guard let res = cpuPaddleMobile?.predictInput(inputPointer, dim: dimsNum, means: meansNumber, scale: net.scale) else {
completion(false, []) guard let res = cpuPaddleMobile?.predictInput(inputPointer, dim: dimsNum) else {
completion(false, nil)
return return
} }
completion(true, res.map { ($0 as! NSNumber).floatValue }) completion(true, res)
} }
/** /**
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import Foundation import Foundation
struct BlockDesc { class BlockDesc {
let index: Int let index: Int
let parentIndex: Int let parentIndex: Int
let vars: [VarDesc] let vars: [VarDesc]
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import Foundation import Foundation
struct OpDesc { class OpDesc {
let inputs: [String : [String]] let inputs: [String : [String]]
var paraInputs: [String : [String]] var paraInputs: [String : [String]]
var outputs: [String : [String]] var outputs: [String : [String]]
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import Foundation import Foundation
public struct Program { public class Program {
let paramPath: String let paramPath: String
let programDesc: ProgramDesc let programDesc: ProgramDesc
let scope: Scope let scope: Scope
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import Foundation import Foundation
public struct ProgramDesc { public class ProgramDesc {
var blocks: [BlockDesc] = [] var blocks: [BlockDesc] = []
init(protoProgram: PaddleMobile_Framework_Proto_ProgramDesc) { init(protoProgram: PaddleMobile_Framework_Proto_ProgramDesc) {
for block in protoProgram.blocks { for block in protoProgram.blocks {
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import Foundation import Foundation
struct TensorDesc { class TensorDesc {
let dims: [Int] let dims: [Int]
let dataType: VarTypeType let dataType: VarTypeType
let dataLayout: DataLayout = DataLayout.NCHW() let dataLayout: DataLayout = DataLayout.NCHW()
......
...@@ -56,7 +56,7 @@ enum VarTypeType: Int { ...@@ -56,7 +56,7 @@ enum VarTypeType: Int {
} }
} }
struct VarDesc { class VarDesc {
let name: String let name: String
let persistable: Bool let persistable: Bool
let type: VarTypeType let type: VarTypeType
......
...@@ -76,11 +76,23 @@ public class Executor<P: PrecisionType> { ...@@ -76,11 +76,23 @@ public class Executor<P: PrecisionType> {
program = inProgram program = inProgram
device = inDevice device = inDevice
queue = inQueue queue = inQueue
// print("before for ")
//print(program.scope.vars["fea_pyramid1_mbox_conf_flat.Flatten.output.1.tmp_0"])
for block in inProgram.programDesc.blocks { for block in inProgram.programDesc.blocks {
//block.ops.count //block.ops.count
for i in 0..<block.ops.count { for i in 0..<block.ops.count {
let op = block.ops[i] let op = block.ops[i]
do { do {
// print("in for i \(i): ")
// print(program.scope.vars["fea_pyramid1_mbox_conf_flat.Flatten.output.1.tmp_0"])
//
// if i == 56 {
// print(program.scope.vars["fea_pyramid1_mbox_conf_flat.Flatten.output.1.tmp_0"])
//
// }
let op = try OpCreator<P>.shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope) let op = try OpCreator<P>.shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope)
ops.append(op) ops.append(op)
} catch let error { } catch let error {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册