From d1d1e9329c4cb11a121757103c9bb0173cf33268 Mon Sep 17 00:00:00 2001 From: liuruilong Date: Tue, 10 Jul 2018 12:00:31 +0800 Subject: [PATCH] add conv add batch norm relu metal --- .../project.pbxproj | 30 +++---- .../paddle-mobile.xcodeproj/project.pbxproj | 32 +++++++ .../paddle-mobile/Common/Extensions.swift | 16 ++++ .../paddle-mobile/Common/MetalExtension.swift | 3 +- .../paddle-mobile/Common/Types.swift | 59 +++++++++++-- .../paddle-mobile/Executor.swift | 12 ++- .../paddle-mobile/paddle-mobile/Loader.swift | 7 +- .../Operators/Base/OpCreator.swift | 18 ++-- .../Operators/Base/Operator.swift | 24 ++--- .../paddle-mobile/Operators/BatchNormOp.swift | 4 +- .../Operators/ConvAddBatchNormReluOp.swift | 68 ++++++++++++--- .../paddle-mobile/Operators/ConvAddOp.swift | 87 +++++++++++++++++++ .../paddle-mobile/Operators/ConvOp.swift | 7 +- .../Operators/ElementwiseAddOp.swift | 7 +- .../paddle-mobile/Operators/FeedOp.swift | 4 +- .../paddle-mobile/Operators/FetchOp.swift | 4 +- .../Kernels/ConvAddBatchNormReluKernel.swift | 54 +++++++++++- .../Operators/Kernels/ConvAddKernel.swift | 19 ++++ .../Operators/Kernels/ConvKernel.metal | 45 ++++++++++ .../Operators/Kernels/PoolKernel.swift | 25 ++++++ .../Operators/Kernels/ReshapeKernel.swift | 26 ++++++ .../Operators/Kernels/SoftmaxKernel.swift | 25 ++++++ .../paddle-mobile/Operators/PoolOp.swift | 39 +++++++++ .../paddle-mobile/Operators/ReluOp.swift | 4 +- .../paddle-mobile/Operators/ReshapeOp.swift | 39 +++++++++ .../paddle-mobile/Operators/SoftmaxOp.swift | 39 +++++++++ .../paddle-mobile/Program/OpDesc.swift | 2 +- .../paddle-mobile/Program/Program.swift | 4 +- .../Program/ProgramOptimize.swift | 48 ++++++---- 29 files changed, 650 insertions(+), 101 deletions(-) create mode 100644 metal/paddle-mobile/paddle-mobile/Operators/ConvAddOp.swift create mode 100644 metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddKernel.swift create mode 100644 metal/paddle-mobile/paddle-mobile/Operators/Kernels/PoolKernel.swift create mode 100644 metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReshapeKernel.swift create mode 100644 metal/paddle-mobile/paddle-mobile/Operators/Kernels/SoftmaxKernel.swift create mode 100644 metal/paddle-mobile/paddle-mobile/Operators/PoolOp.swift create mode 100644 metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift create mode 100644 metal/paddle-mobile/paddle-mobile/Operators/SoftmaxOp.swift diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/project.pbxproj b/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/project.pbxproj index d71ac8605a..1089a98fc5 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/project.pbxproj @@ -214,10 +214,8 @@ FC0E2DB420EDC03C009C1FAC /* conv2d_27.w_0 in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2CEA20EDC03B009C1FAC /* conv2d_27.w_0 */; }; FC0E2DB520EDC03C009C1FAC /* conv2d_33.w_0 in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2CEB20EDC03B009C1FAC /* conv2d_33.w_0 */; }; FC0E2DB620EDC03C009C1FAC /* depthwise_conv2d_7.w_0 in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2CEC20EDC03B009C1FAC /* depthwise_conv2d_7.w_0 */; }; - FCEBC0FC20F227C60099DBAF /* mobilenet in Resources */ = {isa = PBXBuildFile; fileRef = FCEBC0F820F227C60099DBAF /* mobilenet */; }; - FCEBC0FD20F227C60099DBAF /* params in Resources */ = {isa = PBXBuildFile; fileRef = FCEBC0F920F227C60099DBAF /* params */; }; - FCEBC0FE20F227C60099DBAF /* model in Resources */ = {isa = PBXBuildFile; fileRef = FCEBC0FA20F227C60099DBAF /* model */; }; - FCEBC0FF20F227C60099DBAF /* yolo in Resources */ = {isa = PBXBuildFile; fileRef = FCEBC0FB20F227C60099DBAF /* yolo */; }; + FCD04E6320F3146B0007374F /* params in Resources */ = {isa = PBXBuildFile; fileRef = FCD04E6120F3146A0007374F /* params */; }; + FCD04E6420F3146B0007374F /* model in Resources */ = {isa = PBXBuildFile; fileRef = FCD04E6220F3146A0007374F /* model */; }; FCEBEC2C20E1391F00C0B14D /* paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; }; FCEBEC2D20E1391F00C0B14D /* paddle_mobile.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; }; /* End PBXBuildFile section */ @@ -448,10 +446,8 @@ FC0E2CEA20EDC03B009C1FAC /* conv2d_27.w_0 */ = {isa = PBXFileReference; lastKnownFileType = file; path = conv2d_27.w_0; sourceTree = ""; }; FC0E2CEB20EDC03B009C1FAC /* conv2d_33.w_0 */ = {isa = PBXFileReference; lastKnownFileType = file; path = conv2d_33.w_0; sourceTree = ""; }; FC0E2CEC20EDC03B009C1FAC /* depthwise_conv2d_7.w_0 */ = {isa = PBXFileReference; lastKnownFileType = file; path = depthwise_conv2d_7.w_0; sourceTree = ""; }; - FCEBC0F820F227C60099DBAF /* mobilenet */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = mobilenet; sourceTree = ""; }; - FCEBC0F920F227C60099DBAF /* params */ = {isa = PBXFileReference; lastKnownFileType = file; path = params; sourceTree = ""; }; - FCEBC0FA20F227C60099DBAF /* model */ = {isa = PBXFileReference; lastKnownFileType = file; path = model; sourceTree = ""; }; - FCEBC0FB20F227C60099DBAF /* yolo */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = yolo; sourceTree = ""; }; + FCD04E6120F3146A0007374F /* params */ = {isa = PBXFileReference; lastKnownFileType = file; path = params; sourceTree = ""; }; + FCD04E6220F3146A0007374F /* model */ = {isa = PBXFileReference; lastKnownFileType = file; path = model; sourceTree = ""; }; FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; }; /* End PBXFileReference section */ @@ -531,7 +527,7 @@ FC0E2C2020EDC03B009C1FAC /* models */ = { isa = PBXGroup; children = ( - FCEBC0F720F227C60099DBAF /* yolo */, + FCD04E6020F3146A0007374F /* mobilenet */, FC0E2C2420EDC03B009C1FAC /* mobilenetssd */, ); name = models; @@ -745,15 +741,13 @@ path = mobilenetssd; sourceTree = ""; }; - FCEBC0F720F227C60099DBAF /* yolo */ = { + FCD04E6020F3146A0007374F /* mobilenet */ = { isa = PBXGroup; children = ( - FCEBC0F820F227C60099DBAF /* mobilenet */, - FCEBC0F920F227C60099DBAF /* params */, - FCEBC0FA20F227C60099DBAF /* model */, - FCEBC0FB20F227C60099DBAF /* yolo */, + FCD04E6120F3146A0007374F /* params */, + FCD04E6220F3146A0007374F /* model */, ); - path = yolo; + path = mobilenet; sourceTree = ""; }; /* End PBXGroup section */ @@ -828,7 +822,6 @@ FC0E2D2920EDC03B009C1FAC /* batch_norm_2.b_0 in Resources */, FC0E2DA920EDC03C009C1FAC /* conv2d_26.w_0 in Resources */, FC0E2D0420EDC03B009C1FAC /* batch_norm_16.w_2 in Resources */, - FCEBC0FE20F227C60099DBAF /* model in Resources */, FC0E2D0720EDC03B009C1FAC /* batch_norm_6.w_1 in Resources */, FC0E2DB020EDC03C009C1FAC /* batch_norm_30.w_2 in Resources */, FC0E2D9720EDC03C009C1FAC /* conv2d_25.w_0 in Resources */, @@ -844,11 +837,9 @@ FC0E2DA620EDC03C009C1FAC /* depthwise_conv2d_4.w_0 in Resources */, FC0E2D6920EDC03C009C1FAC /* conv2d_6.w_0 in Resources */, FC0E2D6520EDC03C009C1FAC /* conv2d_7.w_0 in Resources */, - FCEBC0FD20F227C60099DBAF /* params in Resources */, FC0E2DAB20EDC03C009C1FAC /* batch_norm_19.w_2 in Resources */, FC0E2D9920EDC03C009C1FAC /* conv2d_31.w_0 in Resources */, FC0E2D3020EDC03B009C1FAC /* batch_norm_34.w_0 in Resources */, - FCEBC0FC20F227C60099DBAF /* mobilenet in Resources */, FC0E2D1220EDC03B009C1FAC /* batch_norm_34.b_0 in Resources */, FC0E2D4D20EDC03C009C1FAC /* batch_norm_7.b_0 in Resources */, FC0E2D2520EDC03B009C1FAC /* batch_norm_21.w_1 in Resources */, @@ -857,6 +848,7 @@ FC0E2D8620EDC03C009C1FAC /* conv2d_23.w_0 in Resources */, FC0E2CFE20EDC03B009C1FAC /* depthwise_conv2d_9.w_0 in Resources */, FC0E2D4C20EDC03C009C1FAC /* batch_norm_8.w_2 in Resources */, + FCD04E6320F3146B0007374F /* params in Resources */, FC0E2D5820EDC03C009C1FAC /* conv2d_5.w_0 in Resources */, FC0E2D1620EDC03B009C1FAC /* batch_norm_3.w_1 in Resources */, FC0E2DB120EDC03C009C1FAC /* batch_norm_24.w_2 in Resources */, @@ -949,7 +941,6 @@ FC0E2D0F20EDC03B009C1FAC /* batch_norm_5.w_0 in Resources */, FC0E2D4520EDC03C009C1FAC /* batch_norm_9.w_2 in Resources */, FC0E2D9020EDC03C009C1FAC /* batch_norm_23.w_2 in Resources */, - FCEBC0FF20F227C60099DBAF /* yolo in Resources */, FC0E2D6720EDC03C009C1FAC /* conv2d_31.b_0 in Resources */, FC0E2DA020EDC03C009C1FAC /* conv2d_18.w_0 in Resources */, FC0E2D1C20EDC03B009C1FAC /* conv2d_13.w_0 in Resources */, @@ -972,6 +963,7 @@ FC0E2D0920EDC03B009C1FAC /* conv2d_14.w_0 in Resources */, FC0E2CF720EDC03B009C1FAC /* batch_norm_28.w_2 in Resources */, FC0E2D9520EDC03C009C1FAC /* depthwise_conv2d_5.w_0 in Resources */, + FCD04E6420F3146B0007374F /* model in Resources */, FC0E2D4A20EDC03C009C1FAC /* conv2d_9.w_0 in Resources */, FC0E2D4E20EDC03C009C1FAC /* batch_norm_19.w_1 in Resources */, FC0E2D3620EDC03C009C1FAC /* batch_norm_18.w_0 in Resources */, diff --git a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj index e64d9a67e6..1eed321f54 100644 --- a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj @@ -45,6 +45,14 @@ FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037F20E22FBB000F735A /* FeedOp.swift */; }; FC9D038220E2312E000F735A /* FetchOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038120E2312E000F735A /* FetchOp.swift */; }; FC9D038420E23B01000F735A /* Texture.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038320E23B01000F735A /* Texture.swift */; }; + FCD04E6620F314C50007374F /* PoolOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E6520F314C50007374F /* PoolOp.swift */; }; + FCD04E6820F315020007374F /* PoolKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E6720F315020007374F /* PoolKernel.swift */; }; + FCD04E6A20F319EC0007374F /* SoftmaxOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E6920F319EC0007374F /* SoftmaxOp.swift */; }; + FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E6B20F31A280007374F /* SoftmaxKernel.swift */; }; + FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E6D20F31B4B0007374F /* ReshapeOp.swift */; }; + FCD04E7020F31B720007374F /* ReshapeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E6F20F31B720007374F /* ReshapeKernel.swift */; }; + FCD04E7220F343420007374F /* ConvAddOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E7120F343420007374F /* ConvAddOp.swift */; }; + FCD04E7420F3437E0007374F /* ConvAddKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E7320F3437E0007374F /* ConvAddKernel.swift */; }; FCEBC0F420F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCEBC0F320F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift */; }; FCEBC0F620F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCEBC0F520F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift */; }; FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCF2D73720E64E70007AC5F5 /* Kernel.swift */; }; @@ -93,6 +101,14 @@ FC9D037F20E22FBB000F735A /* FeedOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FeedOp.swift; sourceTree = ""; }; FC9D038120E2312E000F735A /* FetchOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FetchOp.swift; sourceTree = ""; }; FC9D038320E23B01000F735A /* Texture.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture.swift; sourceTree = ""; }; + FCD04E6520F314C50007374F /* PoolOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PoolOp.swift; sourceTree = ""; }; + FCD04E6720F315020007374F /* PoolKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PoolKernel.swift; sourceTree = ""; }; + FCD04E6920F319EC0007374F /* SoftmaxOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SoftmaxOp.swift; sourceTree = ""; }; + FCD04E6B20F31A280007374F /* SoftmaxKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SoftmaxKernel.swift; sourceTree = ""; }; + FCD04E6D20F31B4B0007374F /* ReshapeOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ReshapeOp.swift; sourceTree = ""; }; + FCD04E6F20F31B720007374F /* ReshapeKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ReshapeKernel.swift; sourceTree = ""; }; + FCD04E7120F343420007374F /* ConvAddOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddOp.swift; sourceTree = ""; }; + FCD04E7320F3437E0007374F /* ConvAddKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddKernel.swift; sourceTree = ""; }; FCEBC0F320F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; name = ConvAddBatchNormReluOp.swift; path = "paddle-mobile/Operators/ConvAddBatchNormReluOp.swift"; sourceTree = SOURCE_ROOT; }; FCEBC0F520F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddBatchNormReluKernel.swift; sourceTree = ""; }; FCF2D73720E64E70007AC5F5 /* Kernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; name = Kernel.swift; path = "paddle-mobile/Operators/Kernels/Kernel.swift"; sourceTree = SOURCE_ROOT; }; @@ -193,6 +209,10 @@ FC039BA820E11CBC0081E9F8 /* ReluOp.swift */, FC9D037F20E22FBB000F735A /* FeedOp.swift */, FC9D038120E2312E000F735A /* FetchOp.swift */, + FCD04E6520F314C50007374F /* PoolOp.swift */, + FCD04E6920F319EC0007374F /* SoftmaxOp.swift */, + FCD04E6D20F31B4B0007374F /* ReshapeOp.swift */, + FCD04E7120F343420007374F /* ConvAddOp.swift */, ); path = Operators; sourceTree = ""; @@ -227,6 +247,10 @@ FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */, FC4CB74820F0B954007C0C6D /* ConvKernel.metal */, FCEBC0F520F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift */, + FCD04E6720F315020007374F /* PoolKernel.swift */, + FCD04E6B20F31A280007374F /* SoftmaxKernel.swift */, + FCD04E6F20F31B720007374F /* ReshapeKernel.swift */, + FCD04E7320F3437E0007374F /* ConvAddKernel.swift */, ); path = Kernels; sourceTree = ""; @@ -346,6 +370,8 @@ FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */, FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */, FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */, + FCD04E7020F31B720007374F /* ReshapeKernel.swift in Sources */, + FCD04E7220F343420007374F /* ConvAddOp.swift in Sources */, FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */, FC9D037920E229E4000F735A /* OpParam.swift in Sources */, FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */, @@ -362,21 +388,27 @@ FC4CB74920F0B954007C0C6D /* ConvKernel.metal in Sources */, FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */, FC9D038420E23B01000F735A /* Texture.swift in Sources */, + FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */, FC039B9820E11C9A0081E9F8 /* Errors.swift in Sources */, FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */, + FCD04E7420F3437E0007374F /* ConvAddKernel.swift in Sources */, FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */, + FCD04E6620F314C50007374F /* PoolOp.swift in Sources */, FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */, FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */, FC0E2DBA20EE3B8D009C1FAC /* ReluKernel.swift in Sources */, FC82735920E3C04200BE430A /* OpCreator.swift in Sources */, FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */, FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */, + FCD04E6A20F319EC0007374F /* SoftmaxOp.swift in Sources */, FC9D038220E2312E000F735A /* FetchOp.swift in Sources */, FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */, FC039BA220E11CB70081E9F8 /* Loader.swift in Sources */, + FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */, FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */, FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */, FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */, + FCD04E6820F315020007374F /* PoolKernel.swift in Sources */, FC039BAD20E11CBC0081E9F8 /* ReluOp.swift in Sources */, FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */, FC039B9720E11C9A0081E9F8 /* Extensions.swift in Sources */, diff --git a/metal/paddle-mobile/paddle-mobile/Common/Extensions.swift b/metal/paddle-mobile/paddle-mobile/Common/Extensions.swift index c33a7e7fa9..946af08b93 100644 --- a/metal/paddle-mobile/paddle-mobile/Common/Extensions.swift +++ b/metal/paddle-mobile/paddle-mobile/Common/Extensions.swift @@ -72,6 +72,16 @@ extension Array: CIntIndex{ } } +extension Array where Element: AnyObject{ + mutating func remove(element: Element) { + if let index = index(where: { (node) -> Bool in + return unsafeBitCast(element, to: Int.self) == unsafeBitCast(node, to: Int.self) + }) { + remove(at: index) + } + } + +} //MARK: Array extension extension Array where Element: Comparable{ @@ -92,4 +102,10 @@ extension String{ } } +func address(o: T) -> String { + return String.init(format: "%018p", unsafeBitCast(o, to: Int.self)) +} + + + diff --git a/metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift b/metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift index dbd015c3eb..16eda20526 100644 --- a/metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift +++ b/metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift @@ -42,7 +42,6 @@ extension MTLDevice { } } - func pipeLine(funcName: String, inPaddleMobileLib: Bool = true) -> MTLComputePipelineState { let useLib = inPaddleMobileLib ? paddleMobileLibrary() : defaultLibrary() guard let function = useLib.makeFunction(name: funcName) else { @@ -65,7 +64,7 @@ extension MTLComputeCommandEncoder { let width = computePipline.threadExecutionWidth let height = computePipline.maxTotalThreadsPerThreadgroup/width let threadsPerGroup = MTLSize.init(width: width, height: height, depth: 1) - + // print(" thread: threads per group: \(threadsPerGroup) ") // print(" thread: out texture width: \(outTexture.width) , out texture height: \(outTexture.height)") diff --git a/metal/paddle-mobile/paddle-mobile/Common/Types.swift b/metal/paddle-mobile/paddle-mobile/Common/Types.swift index f910d33628..98353617f5 100644 --- a/metal/paddle-mobile/paddle-mobile/Common/Types.swift +++ b/metal/paddle-mobile/paddle-mobile/Common/Types.swift @@ -14,22 +14,71 @@ import Foundation +public protocol SummableMultipliable: Equatable { + static func +(lhs: Self, rhs: Self) -> Self + static func *(lhs: Self, rhs: Self) -> Self + static func -(lhs: Self, rhs: Self) -> Self +} +public protocol PrecisionType: SummableMultipliable{ + init(inFloat: Float32) + init(inFloat16: Float16) + init(_ inP: P) + static var bitSize: UInt { get } +} + public typealias Float16 = Int16 extension Float16: PrecisionType { + public static func * (prefix: Float16, postfix: Float16) { + return prefix * postfix + } + + public init

(_ inP: P) where P : PrecisionType { + if P.bitSize == Float32.bitSize { + self = Float16(inFloat: inP as! Float32) + } else if P.bitSize == Float16.bitSize { + self = inP as! Float16 + } else { + fatalError() + } + } + + public static var bitSize: UInt { + return 16 + } + + public init(inFloat16: Float16) { + self = inFloat16 + } public init(inFloat: Float32) { self = Int16(inFloat) } + + + } -public protocol PrecisionType { - init(inFloat: Float32) -} - - extension Float32: PrecisionType { + public init

(_ inP: P) where P : PrecisionType { + if P.bitSize == Float32.bitSize { + self = inP as! Float32 + } else if P.bitSize == Float16.bitSize { + self = Float32.init(inP as! Float16) + } else { + fatalError() + } + } + public init(inFloat: Float32) { self = inFloat } + + public init(inFloat16: Float16) { + self = Float32.init(inFloat16) + } + + public static var bitSize: UInt { + return 32 + } } public enum DataLayout { diff --git a/metal/paddle-mobile/paddle-mobile/Executor.swift b/metal/paddle-mobile/paddle-mobile/Executor.swift index 45188bc3c1..72fa3bda7f 100644 --- a/metal/paddle-mobile/paddle-mobile/Executor.swift +++ b/metal/paddle-mobile/paddle-mobile/Executor.swift @@ -55,7 +55,8 @@ public class Executor { device = inDevice queue = inQueue for block in inProgram.programDesc.blocks { - for op in block.ops { + for i in 0..<7 { + let op = block.ops[i] do { let op = try OpCreator

.shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope) op.inferShape() @@ -64,6 +65,15 @@ public class Executor { throw error } } +// for op in block.ops { +// do { +// let op = try OpCreator

.shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope) +// op.inferShape() +// ops.append(op) +// } catch let error { +// throw error +// } +// } } } diff --git a/metal/paddle-mobile/paddle-mobile/Loader.swift b/metal/paddle-mobile/paddle-mobile/Loader.swift index 449b3f9865..6e9af2930f 100644 --- a/metal/paddle-mobile/paddle-mobile/Loader.swift +++ b/metal/paddle-mobile/paddle-mobile/Loader.swift @@ -104,12 +104,9 @@ public class Loader { serializedData: modelData) let originProgramDesc = ProgramDesc.init(protoProgram: protoProgram) - - let programDesc = ProgramOptimize

.init().optimize(originProgramDesc: originProgramDesc) print(programDesc) - - fatalError() + guard let paraLoader = try? ParaLoader.init(paramPath: paraPath) else { throw PaddleMobileError.loaderError(message: "load para error") } @@ -180,7 +177,7 @@ public class Loader { } } - let program = Program.init(protoProgramDesc: protoProgram, inParamPath: paraPath, inScope: scope) + let program = Program.init(inProgramDesc: programDesc, inParamPath: paraPath, inScope: scope) return program } catch _ { diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift b/metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift index f99c9ff685..0ba02af1c5 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift @@ -40,13 +40,17 @@ class OpCreator { } let opCreators: [String : (MTLDevice, OpDesc, Scope) throws -> Runable & InferShaperable] = - [gConvType : ConvOp

.creat, - gBatchNormType : BatchNormOp

.creat, - gReluType : ReluOp

.creat, - gElementwiseAdd : ElementwiseAddOp

.creat, - gFeedType : FeedOp

.creat, - gFetchType : FetchOp

.creat, - gConvAddBatchNormReluType : ConvAddBatchNormReluOp

.creat] + [gConvType : ConvOp

.creat, + gBatchNormType : BatchNormOp

.creat, + gReluType : ReluOp

.creat, + gElementwiseAdd : ElementwiseAddOp

.creat, + gFeedType : FeedOp

.creat, + gFetchType : FetchOp

.creat, + gConvAddBatchNormReluType : ConvAddBatchNormReluOp

.creat, + gPooType : PoolOp

.creat, + gSoftmaxType : SoftmaxOp

.creat, + gReshapeType : ReshapeOp

.creat, + gConvAddType : ConvAddOp

.creat] private init(){} } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift b/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift index 1bfc184ed3..3ca41ed724 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift @@ -18,9 +18,9 @@ import Foundation protocol Fusion { static func fusionNode() -> Node static func change() -> [String : [(from: String, to: String)]] + static func fusionType() -> String } - protocol Runable { func run(device: MTLDevice, buffer: MTLCommandBuffer) throws func runImpl(device: MTLDevice,buffer: MTLCommandBuffer) throws @@ -117,20 +117,20 @@ let gBatchNormType = "batch_norm" let gReluType = "relu" let gElementwiseAdd = "elementwise_add" let gConvAddBatchNormReluType = "conv_add_batchnorm_relu" +let gPooType = "pool2d" +let gSoftmaxType = "softmax" +let gReshapeType = "reshape" +let gConvAddType = "conv_add" + let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]), gBatchNormType : (inputs: ["X"], outputs: ["Y"]), gReluType : (inputs: ["X"], outputs: ["Out"]), - gElementwiseAdd : (inputs: ["X", "Y"], outputs: ["Out"]), + gElementwiseAdd : (inputs: ["X"], outputs: ["Out"]), gFeedType : (inputs: ["X"], outputs: ["Out"]), gFetchType : (inputs: ["X"], outputs: ["Out"]), - gConvAddBatchNormReluType : (inputs: ["Input"], outputs: ["Out"])] - - - - - - - - - + gConvAddBatchNormReluType : (inputs: ["Input"], outputs: ["Out"]), + gPooType : (inputs: ["X"], outputs: ["Out"]), + gSoftmaxType : (inputs: ["X"], outputs: ["Out"]), + gReshapeType : (inputs: ["X"], outputs: ["Out"]), + gConvAddType : (inputs: ["Input"], outputs: ["Out"])] diff --git a/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift index 3fefe9f2fe..3b45d97c30 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift @@ -14,9 +14,9 @@ import Foundation -struct BatchNormParam: OpParam { +class BatchNormParam: OpParam { typealias ParamPrecisionType = P - init(opDesc: OpDesc, inScope: Scope) throws { + required init(opDesc: OpDesc, inScope: Scope) throws { do { input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: inScope) output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: inScope) diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ConvAddBatchNormReluOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ConvAddBatchNormReluOp.swift index d1f3fd60bb..5b887e82a5 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ConvAddBatchNormReluOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ConvAddBatchNormReluOp.swift @@ -8,21 +8,49 @@ import Foundation - -class ConvAddBatchNormReluOp: Operator, ConvParam

>, Runable, Creator, InferShaperable, Fusion{ - static func fusionNode() -> Node { - let beginNode = Node.init(inType: gConvType) - _ = beginNode - --> Node.init(inType: gElementwiseAdd) - --> Node.init(inType: gBatchNormType) - --> Node.init(inType: gReluType) - return beginNode +class ConvAddBatchNormReluParam: OpParam { + typealias ParamPrecisionType = P + required init(opDesc: OpDesc, inScope: Scope) throws { + do { + filter = try ConvAddBatchNormReluParam.inputFilter(paraInputs: opDesc.paraInputs, from: inScope) + input = try ConvAddBatchNormReluParam.input(inputs: opDesc.inputs, from: inScope) + output = try ConvAddBatchNormReluParam.output(outputs: opDesc.outputs, from: inScope) + stride = try ConvAddBatchNormReluParam.getAttr(key: "strides", attrs: opDesc.attrs) + paddings = try ConvAddBatchNormReluParam.getAttr(key: "paddings", attrs: opDesc.attrs) + dilations = try ConvAddBatchNormReluParam.getAttr(key: "dilations", attrs: opDesc.attrs) + epsilon = try ConvAddBatchNormReluParam.getAttr(key: "epsilon", attrs: opDesc.attrs) + + groups = try ConvAddBatchNormReluParam.getAttr(key: "groups", attrs: opDesc.attrs) + variance = try ConvAddBatchNormReluParam.inputVariance(inputs: opDesc.paraInputs, from: inScope) + bias = try ConvAddBatchNormReluParam.inputBiase(inputs: opDesc.paraInputs, from: inScope) + scale = try ConvAddBatchNormReluParam.inputScale(inputs: opDesc.paraInputs, from: inScope) + mean = try ConvAddBatchNormReluParam.inputMean(inputs: opDesc.paraInputs, from: inScope) + y = try ConvAddBatchNormReluParam.inputY(inputs: opDesc.paraInputs, from: inScope) + } catch let error { + throw error + } } - static func change() -> [String : [(from: String, to: String)]] { - return [:] - } + let input: Texture

+ + let variance: Tensor + let bias: Tensor + let mean: Tensor + let scale: Tensor + let y: Tensor + let filter: Tensor + let epsilon: Float32 + var newScale: MTLBuffer? + var newBiase: MTLBuffer? + var output: Texture

+ let stride: [Int32] + let paddings: [Int32] + let dilations: [Int32] + let groups: Int +} + +class ConvAddBatchNormReluOp: Operator, ConvAddBatchNormReluParam

>, Runable, Creator, InferShaperable, Fusion{ typealias OpType = ConvAddBatchNormReluOp

func inferShape() { @@ -55,4 +83,20 @@ class ConvAddBatchNormReluOp: Operator Node { + let beginNode = Node.init(inType: gConvType) + _ = beginNode + --> Node.init(inType: gElementwiseAdd) + --> Node.init(inType: gBatchNormType) + --> Node.init(inType: gReluType) + return beginNode + } + + static func change() -> [String : [(from: String, to: String)]] { + return [:] + } + + static func fusionType() -> String { + return gConvAddBatchNormReluType + } } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ConvAddOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ConvAddOp.swift new file mode 100644 index 0000000000..3feaded60f --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/ConvAddOp.swift @@ -0,0 +1,87 @@ +// +// ConvAddBatchNormReluOp.swift +// paddle-mobile +// +// Created by liuRuiLong on 2018/7/8. +// Copyright © 2018年 orange. All rights reserved. +// + +import Foundation + +class ConvAddParam: OpParam { + typealias ParamPrecisionType = P + required init(opDesc: OpDesc, inScope: Scope) throws { + do { + filter = try ConvAddParam.inputFilter(paraInputs: opDesc.paraInputs, from: inScope) + input = try ConvAddParam.input(inputs: opDesc.inputs, from: inScope) + output = try ConvAddParam.output(outputs: opDesc.outputs, from: inScope) + stride = try ConvAddParam.getAttr(key: "strides", attrs: opDesc.attrs) + paddings = try ConvAddParam.getAttr(key: "paddings", attrs: opDesc.attrs) + dilations = try ConvAddParam.getAttr(key: "dilations", attrs: opDesc.attrs) + groups = try ConvAddParam.getAttr(key: "groups", attrs: opDesc.attrs) + y = try ConvAddParam.inputY(inputs: opDesc.paraInputs, from: inScope) + } catch let error { + throw error + } + } + + let input: Texture

+ let y: Tensor + let filter: Tensor + + var output: Texture

+ let stride: [Int32] + let paddings: [Int32] + let dilations: [Int32] + let groups: Int +} + +class ConvAddOp: Operator, ConvAddParam

>, Runable, Creator, InferShaperable, Fusion{ + static func fusionNode() -> Node { + let beginNode = Node.init(inType: gConvType) + _ = beginNode + --> Node.init(inType: gElementwiseAdd) + return beginNode + } + + static func change() -> [String : [(from: String, to: String)]] { + return [:] + } + + static func fusionType() -> String { + return gConvAddType + } + + typealias OpType = ConvAddOp

+ + func inferShape() { + let inDims = para.input.dim + let filterDim = para.filter.dim + let strides = para.stride + let paddings = para.paddings + let dilations = para.dilations + + var outDim = [inDims[0]] + for i in 0..: OpParam { +class ConvParam: OpParam { typealias ParamPrecisionType = P - init(opDesc: OpDesc, inScope: Scope) throws { + required init(opDesc: OpDesc, inScope: Scope) throws { do { filter = try ConvParam.inputFilter(paraInputs: opDesc.paraInputs, from: inScope) input = try ConvParam.input(inputs: opDesc.inputs, from: inScope) @@ -25,14 +25,15 @@ struct ConvParam: OpParam { paddings = try ConvParam.getAttr(key: "paddings", attrs: opDesc.attrs) dilations = try ConvParam.getAttr(key: "dilations", attrs: opDesc.attrs) groups = try ConvParam.getAttr(key: "groups", attrs: opDesc.attrs) + } catch let error { throw error } } let input: Texture

- var output: Texture

let filter: Tensor + var output: Texture

let stride: [Int32] let paddings: [Int32] let dilations: [Int32] diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift index 2ea92546fa..5ed36f86d7 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift @@ -14,12 +14,13 @@ import Foundation -struct ElementwiseAddParam: OpParam { +class ElementwiseAddParam: OpParam { typealias ParamPrecisionType = P - init(opDesc: OpDesc, inScope: Scope) throws { + required init(opDesc: OpDesc, inScope: Scope) throws { do { input = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: inScope) - inputY = try ElementwiseAddParam.inputY(inputs: opDesc.inputs, from: inScope) + inputY = try ElementwiseAddParam.inputY(inputs: opDesc.paraInputs, from: inScope) + output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: inScope) axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs) } catch let error { diff --git a/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift index 0249547416..1e12d8af0d 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift @@ -14,14 +14,14 @@ import Foundation -struct FeedParam: OpParam{ +class FeedParam: OpParam{ var output: Texture

var input: InputTexture { return scope.input() as! InputTexture } let scope: Scope - init(opDesc: OpDesc, inScope: Scope) throws { + required init(opDesc: OpDesc, inScope: Scope) throws { scope = inScope do { output = try FeedParam.outputOut(outputs: opDesc.outputs, from: inScope) diff --git a/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift index 0bbfb6de7e..3bc2bb5cb1 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift @@ -14,11 +14,11 @@ import Foundation -struct FetchParam: OpParam{ +class FetchParam: OpParam{ var output: ResultHolder

= ResultHolder.init(inDim: [], inResult: []) let input: Texture

let scope: Scope - init(opDesc: OpDesc, inScope: Scope) throws { + required init(opDesc: OpDesc, inScope: Scope) throws { scope = inScope do { input = try FetchParam.inputX(inputs: opDesc.inputs, from: inScope) diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddBatchNormReluKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddBatchNormReluKernel.swift index 1d190537cf..63961f026b 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddBatchNormReluKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddBatchNormReluKernel.swift @@ -9,11 +9,59 @@ import Foundation class ConvAddBatchNormReluKernel: Kernel, Computable { - required init(device: MTLDevice, param: ConvParam

) { - super.init(device: device, inFunctionName: "conv3x3") + var metalParam: MetalConvParam! + + required init(device: MTLDevice, param: ConvAddBatchNormReluParam

) { + super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_3x3") + + let offsetX = param.filter.dim[2]/2 - Int(param.paddings[0]) + let offsetY = param.filter.dim[1]/2 - Int(param.paddings[1]) + let offsetZ = 0.0 + metalParam = MetalConvParam.init(offsetX: Int16(offsetX), offsetY: Int16(offsetY), offsetZ: Int16(offsetZ), strideX: UInt16(param.stride[0]), strideY: UInt16(param.stride[1]), paddedZ: UInt16(param.input.metalTexture.arrayLength * 4 - param.input.dim[3])) + + var invs: [P] = [] + let varianceContents = param.variance.buffer.contents().assumingMemoryBound(to: P.self) + + for i in 0...stride { + let inv = pow(Float32.init(varianceContents[i]) + param.epsilon, 0.5) + invs.append(P(inv)) + } + + let newScale: UnsafeMutablePointer

= UnsafeMutablePointer

.allocate(capacity: param.scale.buffer.length) + let newBiase: UnsafeMutablePointer

= UnsafeMutablePointer

.allocate(capacity: param.bias.buffer.length) + let scaleContents = param.variance.buffer.contents().assumingMemoryBound(to: P.self) + let biaseContents = param.bias.buffer.contents().assumingMemoryBound(to: P.self) + let meanContents = param.mean.buffer.contents().assumingMemoryBound(to: P.self) + for i in 0...stride { + newScale[i] = invs[i] * scaleContents[i] + newBiase[i] = biaseContents[i] - meanContents[i] * invs[i] * scaleContents[i] + } + param.newBiase = device.makeBuffer(bytes: newBiase, length: param.bias.buffer.length) + param.newScale = device.makeBuffer(bytes: newScale, length: param.scale.buffer.length) + + newScale.deinitialize(count: param.scale.buffer.length) + newScale.deallocate() + + newBiase.deinitialize(count: param.bias.buffer.length) + newBiase.deallocate() } - func compute(commandBuffer: MTLCommandBuffer, param: ConvParam

) throws { + func compute(commandBuffer: MTLCommandBuffer, param: ConvAddBatchNormReluParam

) throws { + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { + throw PaddleMobileError.predictError(message: " encode is nil") + } + + print("ConvAddBatchNormReluKernel compute") + + encoder.setTexture(param.input.metalTexture, index: 0) + encoder.setTexture(param.output.metalTexture, index: 1) + encoder.setBytes(&metalParam, length: MemoryLayout.size, index: 0) + encoder.setBuffer(param.filter.buffer, offset: 0, index: 1) + encoder.setBuffer(param.bias.buffer, offset: 0, index: 2) + encoder.setBuffer(param.newScale!, offset: 0, index: 3) + encoder.setBuffer(param.newBiase!, offset: 0, index: 4) + encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) + encoder.endEncoding() } } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddKernel.swift new file mode 100644 index 0000000000..442be96f79 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddKernel.swift @@ -0,0 +1,19 @@ +// +// ConvKernel.swift +// paddle-mobile +// +// Created by liuRuiLong on 2018/7/5. +// Copyright © 2018年 orange. All rights reserved. +// + +import Foundation + +class ConvAddKernel: Kernel, Computable { + required init(device: MTLDevice, param: ConvAddParam

) { + super.init(device: device, inFunctionName: "conv3x3") + + } + + func compute(commandBuffer: MTLCommandBuffer, param: ConvAddParam

) throws { + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.metal index 595858dd81..ed3bf031eb 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.metal +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.metal @@ -52,3 +52,48 @@ kernel void conv3x3(texture2d_array inTexture [[texture(0) } outTexture.write(output, gid.xy, gid.z); } + +kernel void conv_add_batch_norm_relu_3x3(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + constant MetalConvParam ¶m [[buffer(0)]], + const device half4 *weights [[buffer(1)]], + const device half4 *biase [[buffer(2)]], + const device half4 *new_scale [[buffer(3)]], + const device half4 *new_biase [[buffer(4)]], + uint3 gid [[thread_position_in_grid]]) { + + if (gid.x >= outTexture.get_width() || + gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) { + return; + } + + short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY); + constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero); + const uint wightSliceCount = 36; + uint weithTo = gid.z * wightSliceCount * inTexture.get_array_size(); + half4 output = 0.0; + for (uint i = 0; i < inTexture.get_array_size(); ++i) { + half4 input[9]; + input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i); + input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i); + input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i); + input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i); + input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i); + input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i); + input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i); + input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i); + input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i); + for (int j = 0; j < 9; ++j) { + half4 weight = weights[weithTo + wightSliceCount * i + j * 4]; + output += dot(input[j], weight); + } + } + + output = fmax((output + biase[gid.z]) * new_scale[gid.z] + new_biase[gid.z], 0.0h); + outTexture.write(output, gid.xy, gid.z); + +} + + + diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/PoolKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/PoolKernel.swift new file mode 100644 index 0000000000..ce31b18f34 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/PoolKernel.swift @@ -0,0 +1,25 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +import Foundation + +class PoolKernel: Kernel, Computable{ + + func compute(commandBuffer: MTLCommandBuffer, param: PoolParam

) throws { + } + + required init(device: MTLDevice, param: PoolParam

) { + super.init(device: device, inFunctionName: "relu") + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReshapeKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReshapeKernel.swift new file mode 100644 index 0000000000..ee2633b31c --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReshapeKernel.swift @@ -0,0 +1,26 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +import Foundation + +class ReshapeKernel: Kernel, Computable{ + required init(device: MTLDevice, param: ReshapeParam

) { + super.init(device: device, inFunctionName: "relu") + } + + func compute(commandBuffer: MTLCommandBuffer, param: ReshapeParam

) throws { + } + + +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SoftmaxKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SoftmaxKernel.swift new file mode 100644 index 0000000000..a5e90b0a92 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SoftmaxKernel.swift @@ -0,0 +1,25 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +import Foundation + +class SoftmaxKernel: Kernel, Computable{ + + func compute(commandBuffer: MTLCommandBuffer, param: SoftmaxParam

) throws { + } + + required init(device: MTLDevice, param: SoftmaxParam

) { + super.init(device: device, inFunctionName: "relu") + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/PoolOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/PoolOp.swift new file mode 100644 index 0000000000..594aca05ad --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/PoolOp.swift @@ -0,0 +1,39 @@ +// +// PoolOp.swift +// paddle-mobile +// +// Created by liuRuiLong on 2018/7/9. +// Copyright © 2018年 orange. All rights reserved. +// + +import Foundation + +class PoolParam: OpParam { + typealias ParamPrecisionType = P + required init(opDesc: OpDesc, inScope: Scope) throws { + do { + input = try PoolParam.inputX(inputs: opDesc.inputs, from: inScope) + output = try PoolParam.outputOut(outputs: opDesc.outputs, from: inScope) + } catch let error { + throw error + } + } + let input: Texture

+ var output: Texture

+} + +class PoolOp: Operator, PoolParam

>, Runable, Creator, InferShaperable{ + + func inferShape() { + para.output.dim = para.input.dim + } + + typealias OpType = PoolOp

+ func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { + do { + try kernel.compute(commandBuffer: buffer, param: para) + } catch let error { + throw error + } + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift index 62a57eb447..f65e402cdd 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift @@ -14,9 +14,9 @@ import Foundation -struct ReluParam: OpParam { +class ReluParam: OpParam { typealias ParamPrecisionType = P - init(opDesc: OpDesc, inScope: Scope) throws { + required init(opDesc: OpDesc, inScope: Scope) throws { do { input = try ReluParam.inputX(inputs: opDesc.inputs, from: inScope) output = try ReluParam.outputOut(outputs: opDesc.outputs, from: inScope) diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift new file mode 100644 index 0000000000..b783c6f051 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift @@ -0,0 +1,39 @@ +// +// PoolOp.swift +// paddle-mobile +// +// Created by liuRuiLong on 2018/7/9. +// Copyright © 2018年 orange. All rights reserved. +// + +import Foundation + +class ReshapeParam: OpParam { + typealias ParamPrecisionType = P + required init(opDesc: OpDesc, inScope: Scope) throws { + do { + input = try ReshapeParam.inputX(inputs: opDesc.inputs, from: inScope) + output = try ReshapeParam.outputOut(outputs: opDesc.outputs, from: inScope) + } catch let error { + throw error + } + } + let input: Texture

+ var output: Texture

+} + +class ReshapeOp: Operator, ReshapeParam

>, Runable, Creator, InferShaperable{ + + func inferShape() { + para.output.dim = para.input.dim + } + + typealias OpType = ReshapeOp

+ func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { + do { + try kernel.compute(commandBuffer: buffer, param: para) + } catch let error { + throw error + } + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/SoftmaxOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/SoftmaxOp.swift new file mode 100644 index 0000000000..e925b81655 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/SoftmaxOp.swift @@ -0,0 +1,39 @@ +// +// PoolOp.swift +// paddle-mobile +// +// Created by liuRuiLong on 2018/7/9. +// Copyright © 2018年 orange. All rights reserved. +// + +import Foundation + +class SoftmaxParam: OpParam { + typealias ParamPrecisionType = P + required init(opDesc: OpDesc, inScope: Scope) throws { + do { + input = try SoftmaxParam.inputX(inputs: opDesc.inputs, from: inScope) + output = try SoftmaxParam.outputOut(outputs: opDesc.outputs, from: inScope) + } catch let error { + throw error + } + } + let input: Texture

+ var output: Texture

+} + +class SoftmaxOp: Operator, SoftmaxParam

>, Runable, Creator, InferShaperable{ + + func inferShape() { + para.output.dim = para.input.dim + } + + typealias OpType = SoftmaxOp

+ func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { + do { + try kernel.compute(commandBuffer: buffer, param: para) + } catch let error { + throw error + } + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Program/OpDesc.swift b/metal/paddle-mobile/paddle-mobile/Program/OpDesc.swift index f57d26fb6c..11ca4e8492 100644 --- a/metal/paddle-mobile/paddle-mobile/Program/OpDesc.swift +++ b/metal/paddle-mobile/paddle-mobile/Program/OpDesc.swift @@ -20,7 +20,7 @@ struct OpDesc { let outputs: [String : [String]] let unusedOutputs: [String : [String]] var attrs: [String : Attr] = [:] - let type: String + var type: String init(protoOpDesc: PaddleMobile_Framework_Proto_OpDesc) { type = protoOpDesc.type let creator = { (vars: [PaddleMobile_Framework_Proto_OpDesc.Var], canAdd: (String) -> Bool) -> [String : [String]] in diff --git a/metal/paddle-mobile/paddle-mobile/Program/Program.swift b/metal/paddle-mobile/paddle-mobile/Program/Program.swift index a346af8304..1481677b19 100644 --- a/metal/paddle-mobile/paddle-mobile/Program/Program.swift +++ b/metal/paddle-mobile/paddle-mobile/Program/Program.swift @@ -18,8 +18,8 @@ public struct Program { let paramPath: String let programDesc: ProgramDesc let scope: Scope - init(protoProgramDesc: PaddleMobile_Framework_Proto_ProgramDesc, inParamPath: String, inScope: Scope) { - programDesc = ProgramDesc.init(protoProgram: protoProgramDesc) + init(inProgramDesc: ProgramDesc, inParamPath: String, inScope: Scope) { + programDesc = inProgramDesc paramPath = inParamPath scope = inScope } diff --git a/metal/paddle-mobile/paddle-mobile/Program/ProgramOptimize.swift b/metal/paddle-mobile/paddle-mobile/Program/ProgramOptimize.swift index a21bdb82f1..dc813537cb 100644 --- a/metal/paddle-mobile/paddle-mobile/Program/ProgramOptimize.swift +++ b/metal/paddle-mobile/paddle-mobile/Program/ProgramOptimize.swift @@ -18,7 +18,7 @@ infix operator --> : ChainNode class Node { var inputs: [Node] = [] var outputs: [Node] = [] - let type: String + var type: String var opDesc: OpDesc? init(inOpDesc: OpDesc) { type = inOpDesc.type @@ -36,11 +36,12 @@ class Node { } func depth(begin: UInt = 1) -> UInt { - var beginMax: UInt = 0 + var beginMax: UInt = 1 for output in outputs { let subDepth = output.depth(begin: begin + 1) beginMax = max(begin, subDepth) } + beginMax = max(begin, beginMax) return beginMax } @@ -50,23 +51,26 @@ class Node { return beginNode } - func folderWith(fusion: Fusion.Type) { + func folderWith(fusion: Fusion.Type, removedNodes: inout [Node]) { let fusionNode = fusion.fusionNode() let change = fusion.change() let inOutputs = outputs outputs.removeAll() for i in 0.. { - let fusionOps: [Fusion.Type] = [ConvAddBatchNormReluOp

.self] + let fusionOps: [Fusion.Type] = [ConvAddBatchNormReluOp

.self, ConvAddOp

.self] func optimize(originProgramDesc: ProgramDesc) -> ProgramDesc { guard originProgramDesc.blocks.count == 1 else { @@ -141,7 +149,7 @@ class ProgramOptimize { guard let opInputKeys = opInfos[opDesc.type]?.inputs, let outputKeys = opInfos[opDesc.type]?.outputs else { fatalError() } - + let node = Node.init(inOpDesc: opDesc) for inputKey in opInputKeys { if let inputs = opDesc.inputs[inputKey] { @@ -164,28 +172,32 @@ class ProgramOptimize { nodes.append(node) - if var nodes = typeMapNodes[opDesc.type] { - nodes.append(node) - typeMapNodes[opDesc.type] = nodes + if var inNodes = typeMapNodes[opDesc.type] { + inNodes.append(node) + typeMapNodes[opDesc.type] = inNodes } else { - typeMapNodes[opDesc.type] = [] + typeMapNodes[opDesc.type] = [node] } } for fusion in fusionOps { let fusionNode = fusion.fusionNode() let depth = fusionNode.depth() - print(depth) - if let nodes = typeMapNodes[fusionNode.type] { - for node in nodes { - let toNode = node.to(depth: 4) + if let toMatchNodes = typeMapNodes[fusionNode.type] { + for node in toMatchNodes { + let toNode = node.to(depth: depth) if toNode == fusionNode { // match - node.folderWith(fusion: fusion) + var removeNodes: [Node] = [] + node.folderWith(fusion: fusion, removedNodes: &removeNodes) + for removeNode in removeNodes { + nodes.remove(element: removeNode) + } } + } } } - + var ops: [OpDesc] = [] for node in nodes { ops.append(node.opDesc!) -- GitLab