提交 152d6f85 编写于 作者: L liuruilong

run genet

上级 b8587e31
...@@ -18,8 +18,9 @@ ...@@ -18,8 +18,9 @@
FC27991321343A3A000B6BAD /* CPUCompute.mm in Sources */ = {isa = PBXBuildFile; fileRef = FC27991221343A3A000B6BAD /* CPUCompute.mm */; }; FC27991321343A3A000B6BAD /* CPUCompute.mm in Sources */ = {isa = PBXBuildFile; fileRef = FC27991221343A3A000B6BAD /* CPUCompute.mm */; };
FC3C800F2133F46600D1295E /* MobileNetSSD.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3C800E2133F46600D1295E /* MobileNetSSD.swift */; }; FC3C800F2133F46600D1295E /* MobileNetSSD.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3C800E2133F46600D1295E /* MobileNetSSD.swift */; };
FC3C80112133F4AB00D1295E /* MobileNet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3C80102133F4AB00D1295E /* MobileNet.swift */; }; FC3C80112133F4AB00D1295E /* MobileNet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3C80102133F4AB00D1295E /* MobileNet.swift */; };
FC8CFEDB21351F5D0094D569 /* params in Resources */ = {isa = PBXBuildFile; fileRef = FC8CFED921351F5D0094D569 /* params */; }; FC8CFEDF213521C10094D569 /* genet_model in Resources */ = {isa = PBXBuildFile; fileRef = FC8CFEDD213521C10094D569 /* genet_model */; };
FC8CFEDC21351F5D0094D569 /* model in Resources */ = {isa = PBXBuildFile; fileRef = FC8CFEDA21351F5D0094D569 /* model */; }; FC8CFEE0213521C10094D569 /* genet_params in Resources */ = {isa = PBXBuildFile; fileRef = FC8CFEDE213521C10094D569 /* genet_params */; };
FC8CFEE2213524EA0094D569 /* Genet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC8CFEE1213524EA0094D569 /* Genet.swift */; };
FC918191211DBC3500B6F354 /* paddle-mobile.png in Resources */ = {isa = PBXBuildFile; fileRef = FC918190211DBC3500B6F354 /* paddle-mobile.png */; }; FC918191211DBC3500B6F354 /* paddle-mobile.png in Resources */ = {isa = PBXBuildFile; fileRef = FC918190211DBC3500B6F354 /* paddle-mobile.png */; };
FC918193211DC70500B6F354 /* iphone.JPG in Resources */ = {isa = PBXBuildFile; fileRef = FC918192211DC70500B6F354 /* iphone.JPG */; }; FC918193211DC70500B6F354 /* iphone.JPG in Resources */ = {isa = PBXBuildFile; fileRef = FC918192211DC70500B6F354 /* iphone.JPG */; };
FCA3A16121313E1F00084FE5 /* hand.jpg in Resources */ = {isa = PBXBuildFile; fileRef = FCA3A16021313E1F00084FE5 /* hand.jpg */; }; FCA3A16121313E1F00084FE5 /* hand.jpg in Resources */ = {isa = PBXBuildFile; fileRef = FCA3A16021313E1F00084FE5 /* hand.jpg */; };
...@@ -66,8 +67,9 @@ ...@@ -66,8 +67,9 @@
FC27991421343A46000B6BAD /* CPUCompute.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = CPUCompute.h; sourceTree = "<group>"; }; FC27991421343A46000B6BAD /* CPUCompute.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = CPUCompute.h; sourceTree = "<group>"; };
FC3C800E2133F46600D1295E /* MobileNetSSD.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MobileNetSSD.swift; sourceTree = "<group>"; }; FC3C800E2133F46600D1295E /* MobileNetSSD.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MobileNetSSD.swift; sourceTree = "<group>"; };
FC3C80102133F4AB00D1295E /* MobileNet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MobileNet.swift; sourceTree = "<group>"; }; FC3C80102133F4AB00D1295E /* MobileNet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MobileNet.swift; sourceTree = "<group>"; };
FC8CFED921351F5D0094D569 /* params */ = {isa = PBXFileReference; lastKnownFileType = file; path = params; sourceTree = "<group>"; }; FC8CFEDD213521C10094D569 /* genet_model */ = {isa = PBXFileReference; lastKnownFileType = file; path = genet_model; sourceTree = "<group>"; };
FC8CFEDA21351F5D0094D569 /* model */ = {isa = PBXFileReference; lastKnownFileType = file; path = model; sourceTree = "<group>"; }; FC8CFEDE213521C10094D569 /* genet_params */ = {isa = PBXFileReference; lastKnownFileType = file; path = genet_params; sourceTree = "<group>"; };
FC8CFEE1213524EA0094D569 /* Genet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Genet.swift; sourceTree = "<group>"; };
FC918190211DBC3500B6F354 /* paddle-mobile.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "paddle-mobile.png"; sourceTree = "<group>"; }; FC918190211DBC3500B6F354 /* paddle-mobile.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "paddle-mobile.png"; sourceTree = "<group>"; };
FC918192211DC70500B6F354 /* iphone.JPG */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = iphone.JPG; sourceTree = "<group>"; }; FC918192211DC70500B6F354 /* iphone.JPG */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = iphone.JPG; sourceTree = "<group>"; };
FCA3A16021313E1F00084FE5 /* hand.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = hand.jpg; sourceTree = "<group>"; }; FCA3A16021313E1F00084FE5 /* hand.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = hand.jpg; sourceTree = "<group>"; };
...@@ -137,7 +139,6 @@ ...@@ -137,7 +139,6 @@
FC0E2C2020EDC03B009C1FAC /* models */, FC0E2C2020EDC03B009C1FAC /* models */,
FC0E2C1D20EDC030009C1FAC /* images */, FC0E2C1D20EDC030009C1FAC /* images */,
FC039B8120E11C550081E9F8 /* AppDelegate.swift */, FC039B8120E11C550081E9F8 /* AppDelegate.swift */,
FC013927210204A3008100E3 /* PreProcessKernel.metal */,
FC039B8320E11C550081E9F8 /* ViewController.swift */, FC039B8320E11C550081E9F8 /* ViewController.swift */,
FC039B8520E11C550081E9F8 /* Main.storyboard */, FC039B8520E11C550081E9F8 /* Main.storyboard */,
FC039B8820E11C560081E9F8 /* Assets.xcassets */, FC039B8820E11C560081E9F8 /* Assets.xcassets */,
...@@ -164,7 +165,7 @@ ...@@ -164,7 +165,7 @@
FC0E2C2020EDC03B009C1FAC /* models */ = { FC0E2C2020EDC03B009C1FAC /* models */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC8CFED821351F5D0094D569 /* enet */, FC8CFED821351F5D0094D569 /* genet */,
FCBCCC4F2122EEDC00D94F7E /* mobilenet_ssd_hand */, FCBCCC4F2122EEDC00D94F7E /* mobilenet_ssd_hand */,
FCD04E6020F3146A0007374F /* mobilenet */, FCD04E6020F3146A0007374F /* mobilenet */,
); );
...@@ -175,23 +176,25 @@ ...@@ -175,23 +176,25 @@
FC8CFED2213519540094D569 /* Net */ = { FC8CFED2213519540094D569 /* Net */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC013927210204A3008100E3 /* PreProcessKernel.metal */,
FCBCCC542122EF5400D94F7E /* MetalHelper.swift */, FCBCCC542122EF5400D94F7E /* MetalHelper.swift */,
FC3C800E2133F46600D1295E /* MobileNetSSD.swift */, FC3C800E2133F46600D1295E /* MobileNetSSD.swift */,
FC3C80102133F4AB00D1295E /* MobileNet.swift */, FC3C80102133F4AB00D1295E /* MobileNet.swift */,
FC27990F21341CE5000B6BAD /* Net.swift */, FC27990F21341CE5000B6BAD /* Net.swift */,
FC27991221343A3A000B6BAD /* CPUCompute.mm */, FC27991221343A3A000B6BAD /* CPUCompute.mm */,
FC27991421343A46000B6BAD /* CPUCompute.h */, FC27991421343A46000B6BAD /* CPUCompute.h */,
FC8CFEE1213524EA0094D569 /* Genet.swift */,
); );
path = Net; path = Net;
sourceTree = "<group>"; sourceTree = "<group>";
}; };
FC8CFED821351F5D0094D569 /* enet */ = { FC8CFED821351F5D0094D569 /* genet */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC8CFED921351F5D0094D569 /* params */, FC8CFEDD213521C10094D569 /* genet_model */,
FC8CFEDA21351F5D0094D569 /* model */, FC8CFEDE213521C10094D569 /* genet_params */,
); );
path = enet; path = genet;
sourceTree = "<group>"; sourceTree = "<group>";
}; };
FCBCCC4F2122EEDC00D94F7E /* mobilenet_ssd_hand */ = { FCBCCC4F2122EEDC00D94F7E /* mobilenet_ssd_hand */ = {
...@@ -274,10 +277,9 @@ ...@@ -274,10 +277,9 @@
isa = PBXResourcesBuildPhase; isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647; buildActionMask = 2147483647;
files = ( files = (
FC8CFEDB21351F5D0094D569 /* params in Resources */,
FCD04E6320F3146B0007374F /* params in Resources */, FCD04E6320F3146B0007374F /* params in Resources */,
FC8CFEDC21351F5D0094D569 /* model in Resources */,
FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */, FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */,
FC8CFEE0213521C10094D569 /* genet_params in Resources */,
FC918191211DBC3500B6F354 /* paddle-mobile.png in Resources */, FC918191211DBC3500B6F354 /* paddle-mobile.png in Resources */,
FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */, FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */,
FCBCCC522122EEDC00D94F7E /* ssd_hand_params in Resources */, FCBCCC522122EEDC00D94F7E /* ssd_hand_params in Resources */,
...@@ -285,6 +287,7 @@ ...@@ -285,6 +287,7 @@
FC918193211DC70500B6F354 /* iphone.JPG in Resources */, FC918193211DC70500B6F354 /* iphone.JPG in Resources */,
FCDFD41B211D91C7005AB38B /* synset.txt in Resources */, FCDFD41B211D91C7005AB38B /* synset.txt in Resources */,
FCD04E6420F3146B0007374F /* model in Resources */, FCD04E6420F3146B0007374F /* model in Resources */,
FC8CFEDF213521C10094D569 /* genet_model in Resources */,
FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */, FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */,
FCA3A16121313E1F00084FE5 /* hand.jpg in Resources */, FCA3A16121313E1F00084FE5 /* hand.jpg in Resources */,
FCBCCC532122EEDC00D94F7E /* ssd_hand_model in Resources */, FCBCCC532122EEDC00D94F7E /* ssd_hand_model in Resources */,
...@@ -340,6 +343,7 @@ ...@@ -340,6 +343,7 @@
FC039B8420E11C550081E9F8 /* ViewController.swift in Sources */, FC039B8420E11C550081E9F8 /* ViewController.swift in Sources */,
FC013928210204A3008100E3 /* PreProcessKernel.metal in Sources */, FC013928210204A3008100E3 /* PreProcessKernel.metal in Sources */,
FC27991021341CE5000B6BAD /* Net.swift in Sources */, FC27991021341CE5000B6BAD /* Net.swift in Sources */,
FC8CFEE2213524EA0094D569 /* Genet.swift in Sources */,
FCBCCC552122EF5500D94F7E /* MetalHelper.swift in Sources */, FCBCCC552122EF5500D94F7E /* MetalHelper.swift in Sources */,
FC27991321343A3A000B6BAD /* CPUCompute.mm in Sources */, FC27991321343A3A000B6BAD /* CPUCompute.mm in Sources */,
FC3C80112133F4AB00D1295E /* MobileNet.swift in Sources */, FC3C80112133F4AB00D1295E /* MobileNet.swift in Sources */,
...@@ -490,19 +494,19 @@ ...@@ -490,19 +494,19 @@
buildSettings = { buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
CLANG_ENABLE_MODULES = YES; CLANG_ENABLE_MODULES = YES;
CODE_SIGN_IDENTITY = "iPhone Distribution"; CODE_SIGN_IDENTITY = "iPhone Developer";
CODE_SIGN_STYLE = Manual; CODE_SIGN_STYLE = Automatic;
DEVELOPMENT_TEAM = 6T9LLJKSM4; DEVELOPMENT_TEAM = A798K58VVL;
INFOPLIST_FILE = "paddle-mobile-demo/Info.plist"; INFOPLIST_FILE = "paddle-mobile-demo/Info.plist";
IPHONEOS_DEPLOYMENT_TARGET = 9.0; IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = ( LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)", "$(inherited)",
"@executable_path/Frameworks", "@executable_path/Frameworks",
); );
PRODUCT_BUNDLE_IDENTIFIER = com.baidu.mms.qa; PRODUCT_BUNDLE_IDENTIFIER = "com.baidu.paddle-mobile";
PRODUCT_NAME = "$(TARGET_NAME)"; PRODUCT_NAME = "$(TARGET_NAME)";
PROVISIONING_PROFILE = "ba9c4b24-7bd0-49c5-93cd-e3051e775d6c"; PROVISIONING_PROFILE = "";
PROVISIONING_PROFILE_SPECIFIER = Distribution_MMS; PROVISIONING_PROFILE_SPECIFIER = "";
SWIFT_OBJC_BRIDGING_HEADER = "paddle-mobile-demo/paddle-mobile-demo-Bridging-Header.h"; SWIFT_OBJC_BRIDGING_HEADER = "paddle-mobile-demo/paddle-mobile-demo-Bridging-Header.h";
SWIFT_OPTIMIZATION_LEVEL = "-Onone"; SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 4.0; SWIFT_VERSION = 4.0;
...@@ -516,19 +520,19 @@ ...@@ -516,19 +520,19 @@
buildSettings = { buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
CLANG_ENABLE_MODULES = YES; CLANG_ENABLE_MODULES = YES;
CODE_SIGN_IDENTITY = "iPhone Distribution"; CODE_SIGN_IDENTITY = "iPhone Developer";
CODE_SIGN_STYLE = Manual; CODE_SIGN_STYLE = Automatic;
DEVELOPMENT_TEAM = 6T9LLJKSM4; DEVELOPMENT_TEAM = A798K58VVL;
INFOPLIST_FILE = "paddle-mobile-demo/Info.plist"; INFOPLIST_FILE = "paddle-mobile-demo/Info.plist";
IPHONEOS_DEPLOYMENT_TARGET = 9.0; IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = ( LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)", "$(inherited)",
"@executable_path/Frameworks", "@executable_path/Frameworks",
); );
PRODUCT_BUNDLE_IDENTIFIER = com.baidu.mms.qa; PRODUCT_BUNDLE_IDENTIFIER = "com.baidu.paddle-mobile";
PRODUCT_NAME = "$(TARGET_NAME)"; PRODUCT_NAME = "$(TARGET_NAME)";
PROVISIONING_PROFILE = "ba9c4b24-7bd0-49c5-93cd-e3051e775d6c"; PROVISIONING_PROFILE = "";
PROVISIONING_PROFILE_SPECIFIER = Distribution_MMS; PROVISIONING_PROFILE_SPECIFIER = "";
SWIFT_OBJC_BRIDGING_HEADER = "paddle-mobile-demo/paddle-mobile-demo-Bridging-Header.h"; SWIFT_OBJC_BRIDGING_HEADER = "paddle-mobile-demo/paddle-mobile-demo-Bridging-Header.h";
SWIFT_VERSION = 4.0; SWIFT_VERSION = 4.0;
TARGETED_DEVICE_FAMILY = "1,2"; TARGETED_DEVICE_FAMILY = "1,2";
......
...@@ -42,7 +42,7 @@ ...@@ -42,7 +42,7 @@
</AdditionalOptions> </AdditionalOptions>
</TestAction> </TestAction>
<LaunchAction <LaunchAction
buildConfiguration = "Release" buildConfiguration = "Debug"
selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB" selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB" selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
launchStyle = "0" launchStyle = "0"
......
...@@ -310,6 +310,7 @@ void MultiClassNMSCompute(NMSParam *param) { ...@@ -310,6 +310,7 @@ void MultiClassNMSCompute(NMSParam *param) {
for (int i = 0; i < param.output_size; ++i) { for (int i = 0; i < param.output_size; ++i) {
[output addObject:[NSNumber numberWithFloat:param.output[i]]]; [output addObject:[NSNumber numberWithFloat:param.output[i]]];
} }
delete param.output;
return output; return output;
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
import paddle_mobile
class Genet: Net {
var program: Program?
var executor: Executor<Float32>?
let except: Int = 0
class GenetPreProccess: CusomKernel {
init(device: MTLDevice) {
let s = CusomKernel.Shape.init(inWidth: 128, inHeight: 128, inChannel: 3)
super.init(device: device, inFunctionName: "genet_preprocess", outputDim: s, usePaddleMobileLib: false)
}
}
func resultStr(res: [Float]) -> String {
return " 哈哈 还没好 genet !";
}
var preprocessKernel: CusomKernel
let dim = [1, 128, 128, 3]
let modelPath: String
let paramPath: String
let modelDir: String
init() {
modelPath = Bundle.main.path(forResource: "genet_model", ofType: nil) ?! "model null"
paramPath = Bundle.main.path(forResource: "genet_params", ofType: nil) ?! "para null"
modelDir = ""
preprocessKernel = GenetPreProccess.init(device: MetalHelper.shared.device)
}
}
...@@ -26,7 +26,7 @@ class MobileNet: Net{ ...@@ -26,7 +26,7 @@ class MobileNet: Net{
class MobilenetPreProccess: CusomKernel { class MobilenetPreProccess: CusomKernel {
init(device: MTLDevice) { init(device: MTLDevice) {
let s = CusomKernel.Shape.init(inWidth: 224, inHeight: 224, inChannel: 3) let s = CusomKernel.Shape.init(inWidth: 224, inHeight: 224, inChannel: 3)
super.init(device: device, inFunctionName: "preprocess", outputDim: s, usePaddleMobileLib: false) super.init(device: device, inFunctionName: "mobilenet_preprocess", outputDim: s, usePaddleMobileLib: false)
} }
} }
......
...@@ -69,7 +69,6 @@ class MobileNet_ssd_hand: Net{ ...@@ -69,7 +69,6 @@ class MobileNet_ssd_hand: Net{
let output: [Float32] = result.map { $0.floatValue } let output: [Float32] = result.map { $0.floatValue }
return output return output
} }
......
...@@ -21,18 +21,6 @@ import paddle_mobile ...@@ -21,18 +21,6 @@ import paddle_mobile
import MetalPerformanceShaders import MetalPerformanceShaders
let modelHelperMap: [SupportModel : Net] = [.mobilenet_ssd : MobileNet_ssd_hand.init()]
//let modelHelperMap: [SupportModel : Net] = [.mobilenet : MobileNet.init(), .mobilenet_ssd : MobileNet_ssd_hand.init()]
enum SupportModel: String{
// case mobilenet = "mobilenet"
case mobilenet_ssd = "mobilenetssd"
static func supportedModels() -> [SupportModel] {
//.mobilenet,
return [.mobilenet_ssd]
}
}
protocol Net { protocol Net {
var program: Program? { get set } var program: Program? { get set }
var executor: Executor<Float32>? { get set } var executor: Executor<Float32>? { get set }
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
using namespace metal; using namespace metal;
kernel void preprocess( kernel void mobilenet_preprocess(
texture2d<float, access::read> inTexture [[texture(0)]], texture2d<float, access::read> inTexture [[texture(0)]],
texture2d<float, access::write> outTexture [[texture(1)]], texture2d<float, access::write> outTexture [[texture(1)]],
uint2 gid [[thread_position_in_grid]]) uint2 gid [[thread_position_in_grid]])
...@@ -24,7 +24,7 @@ kernel void preprocess( ...@@ -24,7 +24,7 @@ kernel void preprocess(
outTexture.write(float4(inColor.z, inColor.y, inColor.x, 0.0f), gid); outTexture.write(float4(inColor.z, inColor.y, inColor.x, 0.0f), gid);
} }
kernel void preprocess_half( kernel void mobilenet_preprocess_half(
texture2d<half, access::read> inTexture [[texture(0)]], texture2d<half, access::read> inTexture [[texture(0)]],
texture2d<half, access::write> outTexture [[texture(1)]], texture2d<half, access::write> outTexture [[texture(1)]],
uint2 gid [[thread_position_in_grid]]) uint2 gid [[thread_position_in_grid]])
...@@ -68,5 +68,17 @@ kernel void mobilenet_ssd_preprocess_half( ...@@ -68,5 +68,17 @@ kernel void mobilenet_ssd_preprocess_half(
} }
kernel void genet_preprocess(
texture2d<float, access::read> inTexture [[texture(0)]],
texture2d<float, access::write> outTexture [[texture(1)]],
uint2 gid [[thread_position_in_grid]])
{
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height()) {
return;
}
const auto means = float4(123.68f, 116.78f, 103.94f, 0.0f);
const float4 inColor = (inTexture.read(gid) * 255.0 - means) * 0.017;
outTexture.write(float4(inColor.z, inColor.y, inColor.x, 0.0f), gid);
}
...@@ -19,6 +19,20 @@ import MetalPerformanceShaders ...@@ -19,6 +19,20 @@ import MetalPerformanceShaders
let threadSupport = [1] let threadSupport = [1]
let modelHelperMap: [SupportModel : Net] = [.mobilenet_ssd : MobileNet_ssd_hand.init(), .genet : Genet.init()]
//let modelHelperMap: [SupportModel : Net] = [.mobilenet : MobileNet.init(), .mobilenet_ssd : MobileNet_ssd_hand.init()]
enum SupportModel: String{
// case mobilenet = "mobilenet"
case mobilenet_ssd = "mobilenetssd"
case genet = "enet"
static func supportedModels() -> [SupportModel] {
//.mobilenet,
return [.mobilenet_ssd, .genet]
}
}
class ViewController: UIViewController { class ViewController: UIViewController {
@IBOutlet weak var resultTextView: UITextView! @IBOutlet weak var resultTextView: UITextView!
@IBOutlet weak var selectImageView: UIImageView! @IBOutlet weak var selectImageView: UIImageView!
......
...@@ -130,11 +130,6 @@ public class Executor<P: PrecisionType> { ...@@ -130,11 +130,6 @@ public class Executor<P: PrecisionType> {
// } // }
// return // return
// self.ops[testTo].delogOutput()
// self.ops[91].delogOutput()
// self.ops[92].delogOutput()
// self.ops[93].delogOutput()
let afterDate = Date.init() let afterDate = Date.init()
var resultHolder: ResultHolder<P> var resultHolder: ResultHolder<P>
......
...@@ -58,7 +58,9 @@ class OpCreator<P: PrecisionType> { ...@@ -58,7 +58,9 @@ class OpCreator<P: PrecisionType> {
gDwConvBnReluType : DwConvBNReluOp<P>.creat, gDwConvBnReluType : DwConvBNReluOp<P>.creat,
gMulticlassNMSType : MulticlassNMSOp<P>.creat, gMulticlassNMSType : MulticlassNMSOp<P>.creat,
gTransposeType : TransposeOp<P>.creat, gTransposeType : TransposeOp<P>.creat,
gPriorBoxType : PriorBoxOp<P>.creat] gPriorBoxType : PriorBoxOp<P>.creat,
gPreluType : PreluOp<P>.creat,
gConv2dTransposeType : ConvTransposeOp<P>.creat]
private init(){} private init(){}
} }
...@@ -50,7 +50,7 @@ protocol OpParam { ...@@ -50,7 +50,7 @@ protocol OpParam {
static func getAttr<T>(key: String, attrs: [String : Attr]) throws -> T static func getAttr<T>(key: String, attrs: [String : Attr]) throws -> T
static func inputAlpha<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType static func paramInputAlpha<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
} }
...@@ -63,9 +63,14 @@ extension OpParam { ...@@ -63,9 +63,14 @@ extension OpParam {
guard let mapKeys = map[key], mapKeys.count > 0 else { guard let mapKeys = map[key], mapKeys.count > 0 else {
throw PaddleMobileError.paramError(message: key + " not found in \(map) or maped values is empty") throw PaddleMobileError.paramError(message: key + " not found in \(map) or maped values is empty")
} }
guard let variant = from[mapKeys[0]], let v = variant as? VarType else { guard let variant = from[mapKeys[0]] else {
throw PaddleMobileError.paramError(message: mapKeys[0] + " not found in scope") throw PaddleMobileError.paramError(message: mapKeys[0] + " not found in scope")
} }
guard let v = variant as? VarType else {
throw PaddleMobileError.paramError(message: " type error")
}
return v return v
} }
...@@ -78,7 +83,7 @@ extension OpParam { ...@@ -78,7 +83,7 @@ extension OpParam {
} }
} }
static func inputAlpha<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType { static func paramInputAlpha<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType {
do { do {
let alphaTensor: VarType = try getFirstTensor(key: "Alpha", map: inputs, from: from) let alphaTensor: VarType = try getFirstTensor(key: "Alpha", map: inputs, from: from)
return alphaTensor return alphaTensor
......
...@@ -137,6 +137,9 @@ let gBoxcoderType = "box_coder" ...@@ -137,6 +137,9 @@ let gBoxcoderType = "box_coder"
let gMulticlassNMSType = "multiclass_nms" let gMulticlassNMSType = "multiclass_nms"
let gConvBnReluType = "conv_bn_relu" let gConvBnReluType = "conv_bn_relu"
let gDwConvBnReluType = "depth_conv_bn_relu" let gDwConvBnReluType = "depth_conv_bn_relu"
let gPreluType = "prelu"
let gConv2dTransposeType = "conv2d_transpose"
let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]), let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
gBatchNormType : (inputs: ["X"], outputs: ["Y"]), gBatchNormType : (inputs: ["X"], outputs: ["Y"]),
...@@ -156,4 +159,7 @@ let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Out ...@@ -156,4 +159,7 @@ let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Out
gConvBnReluType : (inputs: ["Input"], outputs: ["Out"]), gConvBnReluType : (inputs: ["Input"], outputs: ["Out"]),
gDwConvBnReluType : (inputs: ["Input"], outputs: ["Out"]), gDwConvBnReluType : (inputs: ["Input"], outputs: ["Out"]),
gMulticlassNMSType : (inputs: ["BBoxes", "Scores"], outputs: ["Out"]), gMulticlassNMSType : (inputs: ["BBoxes", "Scores"], outputs: ["Out"]),
gPriorBoxType : (inputs: ["Input", "Image"], outputs: ["Boxes", "Variances"])] gPriorBoxType : (inputs: ["Input", "Image"], outputs: ["Boxes", "Variances"]),
gPreluType : (inputs: ["X"], outputs: ["Out"]),
gConv2dTransposeType : (inputs: ["Input"], outputs: ["Output"])
]
...@@ -34,7 +34,7 @@ class ConvAddBatchNormReluParam<P: PrecisionType>: OpParam { ...@@ -34,7 +34,7 @@ class ConvAddBatchNormReluParam<P: PrecisionType>: OpParam {
scale = try ConvAddBatchNormReluParam.inputScale(inputs: opDesc.paraInputs, from: inScope) scale = try ConvAddBatchNormReluParam.inputScale(inputs: opDesc.paraInputs, from: inScope)
mean = try ConvAddBatchNormReluParam.inputMean(inputs: opDesc.paraInputs, from: inScope) mean = try ConvAddBatchNormReluParam.inputMean(inputs: opDesc.paraInputs, from: inScope)
y = try ConvAddBatchNormReluParam.inputY(inputs: opDesc.paraInputs, from: inScope) y = try ConvAddBatchNormReluParam.inputY(inputs: opDesc.inputs, from: inScope)
} catch let error { } catch let error {
throw error throw error
} }
......
...@@ -25,6 +25,7 @@ class ConvAddParam<P: PrecisionType>: OpParam { ...@@ -25,6 +25,7 @@ class ConvAddParam<P: PrecisionType>: OpParam {
paddings = try ConvAddParam.getAttr(key: "paddings", attrs: opDesc.attrs) paddings = try ConvAddParam.getAttr(key: "paddings", attrs: opDesc.attrs)
dilations = try ConvAddParam.getAttr(key: "dilations", attrs: opDesc.attrs) dilations = try ConvAddParam.getAttr(key: "dilations", attrs: opDesc.attrs)
groups = try ConvAddParam.getAttr(key: "groups", attrs: opDesc.attrs) groups = try ConvAddParam.getAttr(key: "groups", attrs: opDesc.attrs)
y = try ConvAddParam.inputY(inputs: opDesc.paraInputs, from: inScope) y = try ConvAddParam.inputY(inputs: opDesc.paraInputs, from: inScope)
} catch let error { } catch let error {
throw error throw error
......
...@@ -18,17 +18,27 @@ class ElementwiseAddParam<P: PrecisionType>: OpParam { ...@@ -18,17 +18,27 @@ class ElementwiseAddParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws { required init(opDesc: OpDesc, inScope: Scope) throws {
do { do {
input = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: inScope)
inputY = try ElementwiseAddParam.inputY(inputs: opDesc.paraInputs, from: inScope) inputY = try ElementwiseAddParam.inputY(inputs: opDesc.paraInputs, from: inScope)
} catch _ {
do {
inputYTexture = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: inScope)
} catch let error {
throw error
}
}
do {
input = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: inScope) output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: inScope)
axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs) axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs)
} catch let error { } catch let error {
throw error throw error
} }
} }
let input: Texture<P>
let inputY: Tensor<P> var inputYTexture: Texture<P>?
var inputY: Tensor<P>?
var input: Texture<P>
var output: Texture<P> var output: Texture<P>
let axis: Int let axis: Int
} }
......
...@@ -16,11 +16,12 @@ import Foundation ...@@ -16,11 +16,12 @@ import Foundation
class ElementwiseAddKernel<P: PrecisionType>: Kernel, Computable { class ElementwiseAddKernel<P: PrecisionType>: Kernel, Computable {
required init(device: MTLDevice, param: ElementwiseAddParam<P>) { required init(device: MTLDevice, param: ElementwiseAddParam<P>) {
super.init(device: device, inFunctionName: "elementwise_add") super.init(device: device, inFunctionName: "elementwise_add")
} param.output.initTexture(device: device, inTranspose: param.input.transpose)
}
func compute(commandBuffer: MTLCommandBuffer, param: ElementwiseAddParam<P>) throws {
func compute(commandBuffer: MTLCommandBuffer, param: ElementwiseAddParam<P>) throws { }
}
} }
...@@ -17,6 +17,7 @@ class PreluKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -17,6 +17,7 @@ class PreluKernel<P: PrecisionType>: Kernel, Computable{
} else { } else {
super.init(device: device, inFunctionName: "prelu_other") super.init(device: device, inFunctionName: "prelu_other")
} }
param.output.initTexture(device: device, inTranspose: param.input.transpose)
} }
func compute(commandBuffer: MTLCommandBuffer, param: PreluParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: PreluParam<P>) throws {
......
...@@ -21,7 +21,7 @@ class PreluParam<P: PrecisionType>: OpParam { ...@@ -21,7 +21,7 @@ class PreluParam<P: PrecisionType>: OpParam {
do { do {
input = try PreluParam.inputX(inputs: opDesc.inputs, from: inScope) input = try PreluParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try PreluParam.outputOut(outputs: opDesc.outputs, from: inScope) output = try PreluParam.outputOut(outputs: opDesc.outputs, from: inScope)
alpha = try PreluParam.inputAlpha(inputs: opDesc.inputs, from: inScope) alpha = try PreluParam.paramInputAlpha(inputs: opDesc.paraInputs, from: inScope)
mode = try PreluParam.getAttr(key: "mode", attrs: opDesc.attrs) mode = try PreluParam.getAttr(key: "mode", attrs: opDesc.attrs)
} catch let error { } catch let error {
throw error throw error
......
...@@ -15,67 +15,67 @@ ...@@ -15,67 +15,67 @@
import Foundation import Foundation
struct OpDesc { struct OpDesc {
let inputs: [String : [String]] let inputs: [String : [String]]
var paraInputs: [String : [String]] var paraInputs: [String : [String]]
var outputs: [String : [String]] var outputs: [String : [String]]
let unusedOutputs: [String : [String]] let unusedOutputs: [String : [String]]
var attrs: [String : Attr] = [:] var attrs: [String : Attr] = [:]
var type: String var type: String
init(protoOpDesc: PaddleMobile_Framework_Proto_OpDesc) { init(protoOpDesc: PaddleMobile_Framework_Proto_OpDesc) {
type = protoOpDesc.type type = protoOpDesc.type
let creator = { (vars: [PaddleMobile_Framework_Proto_OpDesc.Var], canAdd: (String) -> Bool) -> [String : [String]] in let creator = { (vars: [PaddleMobile_Framework_Proto_OpDesc.Var], canAdd: (String) -> Bool) -> [String : [String]] in
var map: [String : [String]] = [:] var map: [String : [String]] = [:]
for opDescVar in vars { for opDescVar in vars {
if (canAdd(opDescVar.parameter)) { if (canAdd(opDescVar.parameter)) {
map[opDescVar.parameter] = opDescVar.arguments map[opDescVar.parameter] = opDescVar.arguments
}
}
return map
}
inputs = creator(protoOpDesc.inputs) {
opInfos[protoOpDesc.type]?.inputs.contains($0) ?? false
}
paraInputs = creator(protoOpDesc.inputs) {
!(opInfos[protoOpDesc.type]?.inputs.contains($0) ?? false)
}
outputs = creator(protoOpDesc.outputs) {
opInfos[protoOpDesc.type]?.outputs.contains($0) ?? false
}
unusedOutputs = creator(protoOpDesc.outputs) {
!(opInfos[protoOpDesc.type]?.outputs.contains($0) ?? false)
}
for attr in protoOpDesc.attrs {
if (attr.type != .block) {
attrs[attr.name] = attrWithProtoDesc(attrDesc: attr)
}
} }
}
return map
} }
}
inputs = creator(protoOpDesc.inputs) {
extension OpDesc: CustomStringConvertible, CustomDebugStringConvertible { opInfos[protoOpDesc.type]?.inputs.contains($0) ?? false
var description: String {
var str = ""
str += "op type: \(type): \n"
str += " op inputs: \n"
str += " \(inputs) \n"
str += " op para inputs: \n"
str += " \(paraInputs) \n"
str += " op para outputs: \n"
str += " \(outputs) \n"
str += " op attrs: \n"
str += " \(attrs) \n"
return str
} }
var debugDescription: String { paraInputs = creator(protoOpDesc.inputs) {
return description !(opInfos[protoOpDesc.type]?.inputs.contains($0) ?? false)
} }
outputs = creator(protoOpDesc.outputs) {
opInfos[protoOpDesc.type]?.outputs.contains($0) ?? false
}
unusedOutputs = creator(protoOpDesc.outputs) {
!(opInfos[protoOpDesc.type]?.outputs.contains($0) ?? false)
}
for attr in protoOpDesc.attrs {
if (attr.type != .block) {
attrs[attr.name] = attrWithProtoDesc(attrDesc: attr)
}
}
}
}
extension OpDesc: CustomStringConvertible, CustomDebugStringConvertible {
var description: String {
var str = ""
str += "op type: \(type): \n"
str += " op inputs: \n"
str += " \(inputs) \n"
str += " op para inputs: \n"
str += " \(paraInputs) \n"
str += " op para outputs: \n"
str += " \(outputs) \n"
str += " op attrs: \n"
str += " \(attrs) \n"
return str
}
var debugDescription: String {
return description
}
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册