diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/project.pbxproj b/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/project.pbxproj index 17af3762b15a622b046d0cc65eeca70b76b28705..9d5325c574019655d9c064bed297b12f143a2f63 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/project.pbxproj @@ -18,8 +18,9 @@ FC27991321343A3A000B6BAD /* CPUCompute.mm in Sources */ = {isa = PBXBuildFile; fileRef = FC27991221343A3A000B6BAD /* CPUCompute.mm */; }; FC3C800F2133F46600D1295E /* MobileNetSSD.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3C800E2133F46600D1295E /* MobileNetSSD.swift */; }; FC3C80112133F4AB00D1295E /* MobileNet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3C80102133F4AB00D1295E /* MobileNet.swift */; }; - FC8CFEDB21351F5D0094D569 /* params in Resources */ = {isa = PBXBuildFile; fileRef = FC8CFED921351F5D0094D569 /* params */; }; - FC8CFEDC21351F5D0094D569 /* model in Resources */ = {isa = PBXBuildFile; fileRef = FC8CFEDA21351F5D0094D569 /* model */; }; + FC8CFEDF213521C10094D569 /* genet_model in Resources */ = {isa = PBXBuildFile; fileRef = FC8CFEDD213521C10094D569 /* genet_model */; }; + FC8CFEE0213521C10094D569 /* genet_params in Resources */ = {isa = PBXBuildFile; fileRef = FC8CFEDE213521C10094D569 /* genet_params */; }; + FC8CFEE2213524EA0094D569 /* Genet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC8CFEE1213524EA0094D569 /* Genet.swift */; }; FC918191211DBC3500B6F354 /* paddle-mobile.png in Resources */ = {isa = PBXBuildFile; fileRef = FC918190211DBC3500B6F354 /* paddle-mobile.png */; }; FC918193211DC70500B6F354 /* iphone.JPG in Resources */ = {isa = PBXBuildFile; fileRef = FC918192211DC70500B6F354 /* iphone.JPG */; }; FCA3A16121313E1F00084FE5 /* hand.jpg in Resources */ = {isa = PBXBuildFile; fileRef = FCA3A16021313E1F00084FE5 /* hand.jpg */; }; @@ -66,8 +67,9 @@ FC27991421343A46000B6BAD /* CPUCompute.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = CPUCompute.h; sourceTree = ""; }; FC3C800E2133F46600D1295E /* MobileNetSSD.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MobileNetSSD.swift; sourceTree = ""; }; FC3C80102133F4AB00D1295E /* MobileNet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MobileNet.swift; sourceTree = ""; }; - FC8CFED921351F5D0094D569 /* params */ = {isa = PBXFileReference; lastKnownFileType = file; path = params; sourceTree = ""; }; - FC8CFEDA21351F5D0094D569 /* model */ = {isa = PBXFileReference; lastKnownFileType = file; path = model; sourceTree = ""; }; + FC8CFEDD213521C10094D569 /* genet_model */ = {isa = PBXFileReference; lastKnownFileType = file; path = genet_model; sourceTree = ""; }; + FC8CFEDE213521C10094D569 /* genet_params */ = {isa = PBXFileReference; lastKnownFileType = file; path = genet_params; sourceTree = ""; }; + FC8CFEE1213524EA0094D569 /* Genet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Genet.swift; sourceTree = ""; }; FC918190211DBC3500B6F354 /* paddle-mobile.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "paddle-mobile.png"; sourceTree = ""; }; FC918192211DC70500B6F354 /* iphone.JPG */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = iphone.JPG; sourceTree = ""; }; FCA3A16021313E1F00084FE5 /* hand.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = hand.jpg; sourceTree = ""; }; @@ -137,7 +139,6 @@ FC0E2C2020EDC03B009C1FAC /* models */, FC0E2C1D20EDC030009C1FAC /* images */, FC039B8120E11C550081E9F8 /* AppDelegate.swift */, - FC013927210204A3008100E3 /* PreProcessKernel.metal */, FC039B8320E11C550081E9F8 /* ViewController.swift */, FC039B8520E11C550081E9F8 /* Main.storyboard */, FC039B8820E11C560081E9F8 /* Assets.xcassets */, @@ -164,7 +165,7 @@ FC0E2C2020EDC03B009C1FAC /* models */ = { isa = PBXGroup; children = ( - FC8CFED821351F5D0094D569 /* enet */, + FC8CFED821351F5D0094D569 /* genet */, FCBCCC4F2122EEDC00D94F7E /* mobilenet_ssd_hand */, FCD04E6020F3146A0007374F /* mobilenet */, ); @@ -175,23 +176,25 @@ FC8CFED2213519540094D569 /* Net */ = { isa = PBXGroup; children = ( + FC013927210204A3008100E3 /* PreProcessKernel.metal */, FCBCCC542122EF5400D94F7E /* MetalHelper.swift */, FC3C800E2133F46600D1295E /* MobileNetSSD.swift */, FC3C80102133F4AB00D1295E /* MobileNet.swift */, FC27990F21341CE5000B6BAD /* Net.swift */, FC27991221343A3A000B6BAD /* CPUCompute.mm */, FC27991421343A46000B6BAD /* CPUCompute.h */, + FC8CFEE1213524EA0094D569 /* Genet.swift */, ); path = Net; sourceTree = ""; }; - FC8CFED821351F5D0094D569 /* enet */ = { + FC8CFED821351F5D0094D569 /* genet */ = { isa = PBXGroup; children = ( - FC8CFED921351F5D0094D569 /* params */, - FC8CFEDA21351F5D0094D569 /* model */, + FC8CFEDD213521C10094D569 /* genet_model */, + FC8CFEDE213521C10094D569 /* genet_params */, ); - path = enet; + path = genet; sourceTree = ""; }; FCBCCC4F2122EEDC00D94F7E /* mobilenet_ssd_hand */ = { @@ -274,10 +277,9 @@ isa = PBXResourcesBuildPhase; buildActionMask = 2147483647; files = ( - FC8CFEDB21351F5D0094D569 /* params in Resources */, FCD04E6320F3146B0007374F /* params in Resources */, - FC8CFEDC21351F5D0094D569 /* model in Resources */, FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */, + FC8CFEE0213521C10094D569 /* genet_params in Resources */, FC918191211DBC3500B6F354 /* paddle-mobile.png in Resources */, FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */, FCBCCC522122EEDC00D94F7E /* ssd_hand_params in Resources */, @@ -285,6 +287,7 @@ FC918193211DC70500B6F354 /* iphone.JPG in Resources */, FCDFD41B211D91C7005AB38B /* synset.txt in Resources */, FCD04E6420F3146B0007374F /* model in Resources */, + FC8CFEDF213521C10094D569 /* genet_model in Resources */, FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */, FCA3A16121313E1F00084FE5 /* hand.jpg in Resources */, FCBCCC532122EEDC00D94F7E /* ssd_hand_model in Resources */, @@ -340,6 +343,7 @@ FC039B8420E11C550081E9F8 /* ViewController.swift in Sources */, FC013928210204A3008100E3 /* PreProcessKernel.metal in Sources */, FC27991021341CE5000B6BAD /* Net.swift in Sources */, + FC8CFEE2213524EA0094D569 /* Genet.swift in Sources */, FCBCCC552122EF5500D94F7E /* MetalHelper.swift in Sources */, FC27991321343A3A000B6BAD /* CPUCompute.mm in Sources */, FC3C80112133F4AB00D1295E /* MobileNet.swift in Sources */, @@ -490,19 +494,19 @@ buildSettings = { ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; CLANG_ENABLE_MODULES = YES; - CODE_SIGN_IDENTITY = "iPhone Distribution"; - CODE_SIGN_STYLE = Manual; - DEVELOPMENT_TEAM = 6T9LLJKSM4; + CODE_SIGN_IDENTITY = "iPhone Developer"; + CODE_SIGN_STYLE = Automatic; + DEVELOPMENT_TEAM = A798K58VVL; INFOPLIST_FILE = "paddle-mobile-demo/Info.plist"; IPHONEOS_DEPLOYMENT_TARGET = 9.0; LD_RUNPATH_SEARCH_PATHS = ( "$(inherited)", "@executable_path/Frameworks", ); - PRODUCT_BUNDLE_IDENTIFIER = com.baidu.mms.qa; + PRODUCT_BUNDLE_IDENTIFIER = "com.baidu.paddle-mobile"; PRODUCT_NAME = "$(TARGET_NAME)"; - PROVISIONING_PROFILE = "ba9c4b24-7bd0-49c5-93cd-e3051e775d6c"; - PROVISIONING_PROFILE_SPECIFIER = Distribution_MMS; + PROVISIONING_PROFILE = ""; + PROVISIONING_PROFILE_SPECIFIER = ""; SWIFT_OBJC_BRIDGING_HEADER = "paddle-mobile-demo/paddle-mobile-demo-Bridging-Header.h"; SWIFT_OPTIMIZATION_LEVEL = "-Onone"; SWIFT_VERSION = 4.0; @@ -516,19 +520,19 @@ buildSettings = { ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; CLANG_ENABLE_MODULES = YES; - CODE_SIGN_IDENTITY = "iPhone Distribution"; - CODE_SIGN_STYLE = Manual; - DEVELOPMENT_TEAM = 6T9LLJKSM4; + CODE_SIGN_IDENTITY = "iPhone Developer"; + CODE_SIGN_STYLE = Automatic; + DEVELOPMENT_TEAM = A798K58VVL; INFOPLIST_FILE = "paddle-mobile-demo/Info.plist"; IPHONEOS_DEPLOYMENT_TARGET = 9.0; LD_RUNPATH_SEARCH_PATHS = ( "$(inherited)", "@executable_path/Frameworks", ); - PRODUCT_BUNDLE_IDENTIFIER = com.baidu.mms.qa; + PRODUCT_BUNDLE_IDENTIFIER = "com.baidu.paddle-mobile"; PRODUCT_NAME = "$(TARGET_NAME)"; - PROVISIONING_PROFILE = "ba9c4b24-7bd0-49c5-93cd-e3051e775d6c"; - PROVISIONING_PROFILE_SPECIFIER = Distribution_MMS; + PROVISIONING_PROFILE = ""; + PROVISIONING_PROFILE_SPECIFIER = ""; SWIFT_OBJC_BRIDGING_HEADER = "paddle-mobile-demo/paddle-mobile-demo-Bridging-Header.h"; SWIFT_VERSION = 4.0; TARGETED_DEVICE_FAMILY = "1,2"; diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/paddle-mobile-demo.xcscheme b/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/paddle-mobile-demo.xcscheme index de579675e07b8ffb8869464bebe203b591fa7778..46c65bd36a9ab7027b1cb7a81533dcd553ccb62e 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/paddle-mobile-demo.xcscheme +++ b/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/paddle-mobile-demo.xcscheme @@ -42,7 +42,7 @@ ? + + let except: Int = 0 + + class GenetPreProccess: CusomKernel { + init(device: MTLDevice) { + let s = CusomKernel.Shape.init(inWidth: 128, inHeight: 128, inChannel: 3) + super.init(device: device, inFunctionName: "genet_preprocess", outputDim: s, usePaddleMobileLib: false) + } + } + + func resultStr(res: [Float]) -> String { + return " 哈哈 还没好 genet !"; + } + + var preprocessKernel: CusomKernel + let dim = [1, 128, 128, 3] + let modelPath: String + let paramPath: String + let modelDir: String + + init() { + modelPath = Bundle.main.path(forResource: "genet_model", ofType: nil) ?! "model null" + paramPath = Bundle.main.path(forResource: "genet_params", ofType: nil) ?! "para null" + modelDir = "" + preprocessKernel = GenetPreProccess.init(device: MetalHelper.shared.device) + } + +} diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo/Net/MobileNet.swift b/metal/paddle-mobile-demo/paddle-mobile-demo/Net/MobileNet.swift index 5b2faac0d6056f29e062f854a21b6e759261118d..4c7dba391f35267893d12959bd86e0b859985a89 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo/Net/MobileNet.swift +++ b/metal/paddle-mobile-demo/paddle-mobile-demo/Net/MobileNet.swift @@ -26,7 +26,7 @@ class MobileNet: Net{ class MobilenetPreProccess: CusomKernel { init(device: MTLDevice) { let s = CusomKernel.Shape.init(inWidth: 224, inHeight: 224, inChannel: 3) - super.init(device: device, inFunctionName: "preprocess", outputDim: s, usePaddleMobileLib: false) + super.init(device: device, inFunctionName: "mobilenet_preprocess", outputDim: s, usePaddleMobileLib: false) } } diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo/Net/MobileNetSSD.swift b/metal/paddle-mobile-demo/paddle-mobile-demo/Net/MobileNetSSD.swift index 10805a2b3ba468dc4b48bd30cb837fd1ead0f636..775c5b0caa9462908040be0b4caa80285b0fbafd 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo/Net/MobileNetSSD.swift +++ b/metal/paddle-mobile-demo/paddle-mobile-demo/Net/MobileNetSSD.swift @@ -69,7 +69,6 @@ class MobileNet_ssd_hand: Net{ let output: [Float32] = result.map { $0.floatValue } - return output } diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo/Net/Net.swift b/metal/paddle-mobile-demo/paddle-mobile-demo/Net/Net.swift index 6aada07d5c7e97dd6d8e3470d59babd38b0171f5..eb22e462da5af851e1e6b8b8004cdc57bcbb31db 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo/Net/Net.swift +++ b/metal/paddle-mobile-demo/paddle-mobile-demo/Net/Net.swift @@ -21,18 +21,6 @@ import paddle_mobile import MetalPerformanceShaders -let modelHelperMap: [SupportModel : Net] = [.mobilenet_ssd : MobileNet_ssd_hand.init()] -//let modelHelperMap: [SupportModel : Net] = [.mobilenet : MobileNet.init(), .mobilenet_ssd : MobileNet_ssd_hand.init()] - -enum SupportModel: String{ -// case mobilenet = "mobilenet" - case mobilenet_ssd = "mobilenetssd" - static func supportedModels() -> [SupportModel] { - //.mobilenet, - return [.mobilenet_ssd] - } -} - protocol Net { var program: Program? { get set } var executor: Executor? { get set } diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo/PreProcessKernel.metal b/metal/paddle-mobile-demo/paddle-mobile-demo/Net/PreProcessKernel.metal similarity index 78% rename from metal/paddle-mobile-demo/paddle-mobile-demo/PreProcessKernel.metal rename to metal/paddle-mobile-demo/paddle-mobile-demo/Net/PreProcessKernel.metal index ab43f8d9233c581cd6957befb2decafc557f1264..bd0f84cdad16fdb61a525f5653d53b2aeca0d1aa 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo/PreProcessKernel.metal +++ b/metal/paddle-mobile-demo/paddle-mobile-demo/Net/PreProcessKernel.metal @@ -10,7 +10,7 @@ using namespace metal; -kernel void preprocess( +kernel void mobilenet_preprocess( texture2d inTexture [[texture(0)]], texture2d outTexture [[texture(1)]], uint2 gid [[thread_position_in_grid]]) @@ -24,7 +24,7 @@ kernel void preprocess( outTexture.write(float4(inColor.z, inColor.y, inColor.x, 0.0f), gid); } -kernel void preprocess_half( +kernel void mobilenet_preprocess_half( texture2d inTexture [[texture(0)]], texture2d outTexture [[texture(1)]], uint2 gid [[thread_position_in_grid]]) @@ -68,5 +68,17 @@ kernel void mobilenet_ssd_preprocess_half( } - +kernel void genet_preprocess( + texture2d inTexture [[texture(0)]], + texture2d outTexture [[texture(1)]], + uint2 gid [[thread_position_in_grid]]) +{ + if (gid.x >= outTexture.get_width() || + gid.y >= outTexture.get_height()) { + return; + } + const auto means = float4(123.68f, 116.78f, 103.94f, 0.0f); + const float4 inColor = (inTexture.read(gid) * 255.0 - means) * 0.017; + outTexture.write(float4(inColor.z, inColor.y, inColor.x, 0.0f), gid); +} diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift b/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift index d910905bc544f50a149386ce3191a77f311ffa13..81eefae40ca1d16546a06f7ce74f27be1e6d7d91 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift +++ b/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift @@ -19,6 +19,20 @@ import MetalPerformanceShaders let threadSupport = [1] +let modelHelperMap: [SupportModel : Net] = [.mobilenet_ssd : MobileNet_ssd_hand.init(), .genet : Genet.init()] +//let modelHelperMap: [SupportModel : Net] = [.mobilenet : MobileNet.init(), .mobilenet_ssd : MobileNet_ssd_hand.init()] + +enum SupportModel: String{ + // case mobilenet = "mobilenet" + case mobilenet_ssd = "mobilenetssd" + case genet = "enet" + static func supportedModels() -> [SupportModel] { + //.mobilenet, + return [.mobilenet_ssd, .genet] + } +} + + class ViewController: UIViewController { @IBOutlet weak var resultTextView: UITextView! @IBOutlet weak var selectImageView: UIImageView! diff --git a/metal/paddle-mobile/paddle-mobile/Executor.swift b/metal/paddle-mobile/paddle-mobile/Executor.swift index 5c6e2c299c2aa5a5185019f4b64be272214286c1..91aff13bd5769d302f42712141501638a8dcd8c3 100644 --- a/metal/paddle-mobile/paddle-mobile/Executor.swift +++ b/metal/paddle-mobile/paddle-mobile/Executor.swift @@ -130,11 +130,6 @@ public class Executor { // } // return -// self.ops[testTo].delogOutput() -// self.ops[91].delogOutput() -// self.ops[92].delogOutput() -// self.ops[93].delogOutput() - let afterDate = Date.init() var resultHolder: ResultHolder

diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift b/metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift index 248a0162e4425ed13efcee46d4177a2600534524..1e92c342c2b118ae08116b2d75b52c67e19328e6 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift @@ -58,7 +58,9 @@ class OpCreator { gDwConvBnReluType : DwConvBNReluOp

.creat, gMulticlassNMSType : MulticlassNMSOp

.creat, gTransposeType : TransposeOp

.creat, - gPriorBoxType : PriorBoxOp

.creat] - + gPriorBoxType : PriorBoxOp

.creat, + gPreluType : PreluOp

.creat, + gConv2dTransposeType : ConvTransposeOp

.creat] + private init(){} } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Base/OpParam.swift b/metal/paddle-mobile/paddle-mobile/Operators/Base/OpParam.swift index b972838a53006a7b3a5d9f33e7105dc3aa95e31a..9f868e35864d59be5711c4ac0a02787638eeae8f 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Base/OpParam.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Base/OpParam.swift @@ -50,7 +50,7 @@ protocol OpParam { static func getAttr(key: String, attrs: [String : Attr]) throws -> T - static func inputAlpha(inputs: [String : [String]], from: Scope) throws -> VarType + static func paramInputAlpha(inputs: [String : [String]], from: Scope) throws -> VarType } @@ -63,9 +63,14 @@ extension OpParam { guard let mapKeys = map[key], mapKeys.count > 0 else { throw PaddleMobileError.paramError(message: key + " not found in \(map) or maped values is empty") } - guard let variant = from[mapKeys[0]], let v = variant as? VarType else { + guard let variant = from[mapKeys[0]] else { throw PaddleMobileError.paramError(message: mapKeys[0] + " not found in scope") } + + guard let v = variant as? VarType else { + throw PaddleMobileError.paramError(message: " type error") + + } return v } @@ -78,7 +83,7 @@ extension OpParam { } } - static func inputAlpha(inputs: [String : [String]], from: Scope) throws -> VarType { + static func paramInputAlpha(inputs: [String : [String]], from: Scope) throws -> VarType { do { let alphaTensor: VarType = try getFirstTensor(key: "Alpha", map: inputs, from: from) return alphaTensor diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift b/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift index 424d55a7702cb2d8b04a62e2a93bb5ca11a1732e..383a84726f4798a1223262a76f0b666ab7928ff4 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift @@ -137,6 +137,9 @@ let gBoxcoderType = "box_coder" let gMulticlassNMSType = "multiclass_nms" let gConvBnReluType = "conv_bn_relu" let gDwConvBnReluType = "depth_conv_bn_relu" +let gPreluType = "prelu" +let gConv2dTransposeType = "conv2d_transpose" + let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]), gBatchNormType : (inputs: ["X"], outputs: ["Y"]), @@ -156,4 +159,7 @@ let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Out gConvBnReluType : (inputs: ["Input"], outputs: ["Out"]), gDwConvBnReluType : (inputs: ["Input"], outputs: ["Out"]), gMulticlassNMSType : (inputs: ["BBoxes", "Scores"], outputs: ["Out"]), - gPriorBoxType : (inputs: ["Input", "Image"], outputs: ["Boxes", "Variances"])] + gPriorBoxType : (inputs: ["Input", "Image"], outputs: ["Boxes", "Variances"]), + gPreluType : (inputs: ["X"], outputs: ["Out"]), + gConv2dTransposeType : (inputs: ["Input"], outputs: ["Output"]) + ] diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ConvAddBatchNormReluOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ConvAddBatchNormReluOp.swift index a5b1312e7cfc4982b06e0ee7912bba158c0e4d54..6f67014444e5ef82fe4cdc30f99bc371fef2d417 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ConvAddBatchNormReluOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ConvAddBatchNormReluOp.swift @@ -34,7 +34,7 @@ class ConvAddBatchNormReluParam: OpParam { scale = try ConvAddBatchNormReluParam.inputScale(inputs: opDesc.paraInputs, from: inScope) mean = try ConvAddBatchNormReluParam.inputMean(inputs: opDesc.paraInputs, from: inScope) - y = try ConvAddBatchNormReluParam.inputY(inputs: opDesc.paraInputs, from: inScope) + y = try ConvAddBatchNormReluParam.inputY(inputs: opDesc.inputs, from: inScope) } catch let error { throw error } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ConvAddOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ConvAddOp.swift index 960eb4667bee4e3a0c9ba57737728b46a4505836..4a1bad4ef17c9b1051500290438eb3f5fffc23ef 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ConvAddOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ConvAddOp.swift @@ -25,6 +25,7 @@ class ConvAddParam: OpParam { paddings = try ConvAddParam.getAttr(key: "paddings", attrs: opDesc.attrs) dilations = try ConvAddParam.getAttr(key: "dilations", attrs: opDesc.attrs) groups = try ConvAddParam.getAttr(key: "groups", attrs: opDesc.attrs) + y = try ConvAddParam.inputY(inputs: opDesc.paraInputs, from: inScope) } catch let error { throw error diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift index aeb1bea230a25623fce7c73f5025e166f501a7c6..bf17a9ee5faec80701a0e2efb33ad4765aaaf1df 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift @@ -18,17 +18,27 @@ class ElementwiseAddParam: OpParam { typealias ParamPrecisionType = P required init(opDesc: OpDesc, inScope: Scope) throws { do { - input = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: inScope) inputY = try ElementwiseAddParam.inputY(inputs: opDesc.paraInputs, from: inScope) - + } catch _ { + do { + inputYTexture = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: inScope) + } catch let error { + throw error + } + } + do { + input = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: inScope) output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: inScope) axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs) } catch let error { throw error } } - let input: Texture

- let inputY: Tensor

+ + var inputYTexture: Texture

? + var inputY: Tensor

? + var input: Texture

+ var output: Texture

let axis: Int } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ElementwiseAddKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ElementwiseAddKernel.swift index 361e77950841f2fa2b54884a2fbf394714f10902..2050c38c3477917ba2d568504665593f672057d0 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ElementwiseAddKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ElementwiseAddKernel.swift @@ -16,11 +16,12 @@ import Foundation class ElementwiseAddKernel: Kernel, Computable { - required init(device: MTLDevice, param: ElementwiseAddParam

) { - super.init(device: device, inFunctionName: "elementwise_add") - } + required init(device: MTLDevice, param: ElementwiseAddParam

) { + super.init(device: device, inFunctionName: "elementwise_add") + param.output.initTexture(device: device, inTranspose: param.input.transpose) + } + + func compute(commandBuffer: MTLCommandBuffer, param: ElementwiseAddParam

) throws { - func compute(commandBuffer: MTLCommandBuffer, param: ElementwiseAddParam

) throws { - - } + } } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/PreluKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/PreluKernel.swift index 1b54fdde38cb6e98b8edeb454bf40f2f141e560c..d1d82aeb539e9c433c579e797c187c682b2ac235 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/PreluKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/PreluKernel.swift @@ -17,6 +17,7 @@ class PreluKernel: Kernel, Computable{ } else { super.init(device: device, inFunctionName: "prelu_other") } + param.output.initTexture(device: device, inTranspose: param.input.transpose) } func compute(commandBuffer: MTLCommandBuffer, param: PreluParam

) throws { diff --git a/metal/paddle-mobile/paddle-mobile/Operators/PreluOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/PreluOp.swift index dff120226b6d65ceae32120b1149f21fd45456f2..44b509eb66be4ada2a5e46af73d1a97011cd9b85 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/PreluOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/PreluOp.swift @@ -21,7 +21,7 @@ class PreluParam: OpParam { do { input = try PreluParam.inputX(inputs: opDesc.inputs, from: inScope) output = try PreluParam.outputOut(outputs: opDesc.outputs, from: inScope) - alpha = try PreluParam.inputAlpha(inputs: opDesc.inputs, from: inScope) + alpha = try PreluParam.paramInputAlpha(inputs: opDesc.paraInputs, from: inScope) mode = try PreluParam.getAttr(key: "mode", attrs: opDesc.attrs) } catch let error { throw error diff --git a/metal/paddle-mobile/paddle-mobile/Program/OpDesc.swift b/metal/paddle-mobile/paddle-mobile/Program/OpDesc.swift index 73f81152316ad6812f705979b9c2358ee03eb3c8..45f5d529503c7e985917a5e789b02b0bdbfc767e 100644 --- a/metal/paddle-mobile/paddle-mobile/Program/OpDesc.swift +++ b/metal/paddle-mobile/paddle-mobile/Program/OpDesc.swift @@ -15,67 +15,67 @@ import Foundation struct OpDesc { - let inputs: [String : [String]] - var paraInputs: [String : [String]] - var outputs: [String : [String]] - let unusedOutputs: [String : [String]] - var attrs: [String : Attr] = [:] - var type: String - init(protoOpDesc: PaddleMobile_Framework_Proto_OpDesc) { - type = protoOpDesc.type - let creator = { (vars: [PaddleMobile_Framework_Proto_OpDesc.Var], canAdd: (String) -> Bool) -> [String : [String]] in - var map: [String : [String]] = [:] - for opDescVar in vars { - if (canAdd(opDescVar.parameter)) { - map[opDescVar.parameter] = opDescVar.arguments - } - } - return map - } - - inputs = creator(protoOpDesc.inputs) { - opInfos[protoOpDesc.type]?.inputs.contains($0) ?? false - } - - paraInputs = creator(protoOpDesc.inputs) { - !(opInfos[protoOpDesc.type]?.inputs.contains($0) ?? false) - } - - outputs = creator(protoOpDesc.outputs) { - opInfos[protoOpDesc.type]?.outputs.contains($0) ?? false - } - - unusedOutputs = creator(protoOpDesc.outputs) { - !(opInfos[protoOpDesc.type]?.outputs.contains($0) ?? false) - } - - for attr in protoOpDesc.attrs { - if (attr.type != .block) { - attrs[attr.name] = attrWithProtoDesc(attrDesc: attr) - } + let inputs: [String : [String]] + var paraInputs: [String : [String]] + var outputs: [String : [String]] + let unusedOutputs: [String : [String]] + var attrs: [String : Attr] = [:] + var type: String + init(protoOpDesc: PaddleMobile_Framework_Proto_OpDesc) { + type = protoOpDesc.type + let creator = { (vars: [PaddleMobile_Framework_Proto_OpDesc.Var], canAdd: (String) -> Bool) -> [String : [String]] in + var map: [String : [String]] = [:] + for opDescVar in vars { + if (canAdd(opDescVar.parameter)) { + map[opDescVar.parameter] = opDescVar.arguments } + } + return map } -} - -extension OpDesc: CustomStringConvertible, CustomDebugStringConvertible { - var description: String { - var str = "" - str += "op type: \(type): \n" - str += " op inputs: \n" - str += " \(inputs) \n" - str += " op para inputs: \n" - str += " \(paraInputs) \n" - str += " op para outputs: \n" - str += " \(outputs) \n" - str += " op attrs: \n" - str += " \(attrs) \n" - - return str + + inputs = creator(protoOpDesc.inputs) { + opInfos[protoOpDesc.type]?.inputs.contains($0) ?? false } - var debugDescription: String { - return description + paraInputs = creator(protoOpDesc.inputs) { + !(opInfos[protoOpDesc.type]?.inputs.contains($0) ?? false) } + outputs = creator(protoOpDesc.outputs) { + opInfos[protoOpDesc.type]?.outputs.contains($0) ?? false + } + + unusedOutputs = creator(protoOpDesc.outputs) { + !(opInfos[protoOpDesc.type]?.outputs.contains($0) ?? false) + } + + for attr in protoOpDesc.attrs { + if (attr.type != .block) { + attrs[attr.name] = attrWithProtoDesc(attrDesc: attr) + } + } + } +} + +extension OpDesc: CustomStringConvertible, CustomDebugStringConvertible { + var description: String { + var str = "" + str += "op type: \(type): \n" + str += " op inputs: \n" + str += " \(inputs) \n" + str += " op para inputs: \n" + str += " \(paraInputs) \n" + str += " op para outputs: \n" + str += " \(outputs) \n" + str += " op attrs: \n" + str += " \(attrs) \n" + return str + } + + var debugDescription: String { + return description + } + + }