diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift b/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift
index 69fb820cd5ff34633d0084d697aa8d986f0117d1..e249c3aea909869aec73f5ac8fa3d8ca63382c0f 100644
--- a/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift
+++ b/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift
@@ -17,11 +17,11 @@ import MetalKit
import paddle_mobile
import MetalPerformanceShaders
-let platform: Platform = .CPU
+let platform: Platform = .GPU
let threadSupport = [1]
-let modelHelperMap: [SupportModel : Runner] = [.mobilenet_ssd : Runner.init(inNet: MobileNet_ssd_hand.init(), commandQueue: MetalHelper.shared.queue, inPlatform: platform),
- .genet : Runner.init(inNet: Genet.init(), commandQueue: MetalHelper.shared.queue, inPlatform: platform)]
+let modelHelperMap: [SupportModel : Runner] = [.mobilenet_ssd : Runner.init(inNet: MobileNet_ssd_hand.init(device: MetalHelper.shared.device), commandQueue: MetalHelper.shared.queue, inPlatform: platform),
+ .genet : Runner.init(inNet: Genet.init(device: MetalHelper.shared.device), commandQueue: MetalHelper.shared.queue, inPlatform: platform)]
//, .genet : Genet.init()
//let modelHelperMap: [SupportModel : Net] = [.mobilenet : MobileNet.init(), .mobilenet_ssd : MobileNet_ssd_hand.init()]
@@ -48,6 +48,7 @@ class ViewController: UIViewController {
var toPredictTexture: MTLTexture?
var runner: Runner {
+
get {
return modelHelperMap[modelType] ?! " has no this type "
}
diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo/paddle-mobile-demo-Bridging-Header.h b/metal/paddle-mobile-demo/paddle-mobile-demo/paddle-mobile-demo-Bridging-Header.h
index 7a56ca282a3bbd4743e7440934efd9f16068a6b6..25434fa4b69ae0a362b0811291a49d91d4e13dc9 100644
--- a/metal/paddle-mobile-demo/paddle-mobile-demo/paddle-mobile-demo-Bridging-Header.h
+++ b/metal/paddle-mobile-demo/paddle-mobile-demo/paddle-mobile-demo-Bridging-Header.h
@@ -4,4 +4,4 @@
-#import "CPUCompute.h"
+//#import
diff --git a/metal/paddle-mobile/paddle-mobile/Genet.swift b/metal/paddle-mobile/paddle-mobile/Genet.swift
index 2479cdf26ea09c933248c04ae99173ff0cfb7164..40c190ef875f2fa559eec8c1999de98694d793e1 100644
--- a/metal/paddle-mobile/paddle-mobile/Genet.swift
+++ b/metal/paddle-mobile/paddle-mobile/Genet.swift
@@ -14,8 +14,8 @@
import Foundation
-class Genet: Net {
- @objc override init(device: MTLDevice) {
+public class Genet: Net {
+ @objc public override init(device: MTLDevice) {
super.init(device: device)
means = [128.0, 128.0, 128.0]
scale = 0.017
@@ -34,7 +34,7 @@ class Genet: Net {
}
}
- override func resultStr(res: [Float]) -> String {
+ override public func resultStr(res: [Float]) -> String {
return " \(Array(res.suffix(10))) ... "
}
diff --git a/metal/paddle-mobile/paddle-mobile/MobileNetSSD.swift b/metal/paddle-mobile/paddle-mobile/MobileNetSSD.swift
index a7901d05522aadac713a87ccb0861c13ea684ce4..47003043d93e4f685cb4a1adaeb897b2af19fb9d 100644
--- a/metal/paddle-mobile/paddle-mobile/MobileNetSSD.swift
+++ b/metal/paddle-mobile/paddle-mobile/MobileNetSSD.swift
@@ -14,8 +14,8 @@
import Foundation
-class MobileNet_ssd_hand: Net{
- @objc override init(device: MTLDevice) {
+public class MobileNet_ssd_hand: Net{
+ @objc public override init(device: MTLDevice) {
super.init(device: device)
means = [123.68, 116.78, 103.94]
scale = 0.017
@@ -34,7 +34,7 @@ class MobileNet_ssd_hand: Net{
}
}
- override func resultStr(res: [Float]) -> String {
+ override public func resultStr(res: [Float]) -> String {
return " \(res)"
}
diff --git a/metal/paddle-mobile/paddle-mobile/PaddleMobile.swift b/metal/paddle-mobile/paddle-mobile/PaddleMobile.swift
index d3f089e9021f6e7af7e42de8ef1e4f9ad0c65d44..be768f665f1d3ea315cb129e75ffae48038b3f93 100644
--- a/metal/paddle-mobile/paddle-mobile/PaddleMobile.swift
+++ b/metal/paddle-mobile/paddle-mobile/PaddleMobile.swift
@@ -33,7 +33,7 @@ public class Net: NSObject {
var modelPath: String = ""
var paramPath: String = ""
var modelDir: String = ""
- func resultStr(res: [Float]) -> String {
+ public func resultStr(res: [Float]) -> String {
fatalError()
}
func fetchResult(paddleMobileRes: ResultHolder) -> [Float32] {
diff --git a/metal/paddle-mobile/paddle-mobile/PaddleMobileGPU.h b/metal/paddle-mobile/paddle-mobile/PaddleMobileGPU.h
index 459a5b4bc1808eb4b9f31f24f9c4858696f644ec..0f3235d5e87539663fe69fa9b7ca8a0458278cad 100644
--- a/metal/paddle-mobile/paddle-mobile/PaddleMobileGPU.h
+++ b/metal/paddle-mobile/paddle-mobile/PaddleMobileGPU.h
@@ -16,9 +16,9 @@
#import
typedef enum : NSUInteger {
- MobileNet,
- MobileNetSSD,
- Genet,
+ MobileNetType,
+ MobileNetSSDType,
+ GenetType,
} NetType;
@interface ModelConfig: NSObject