From 857ffcbcd013413c566c599a3a495e00267eb98c Mon Sep 17 00:00:00 2001 From: NazgulLee Date: Thu, 16 May 2019 22:02:48 +0800 Subject: [PATCH] add reshape2, transpose2, relu6 and scale operator (#1628) --- .../project.pbxproj | 4 + .../paddle-mobile-metallib/ReluKernel.metal | 31 +++++++ .../paddle-mobile-metallib/ScaleKernel.metal | 82 +++++++++++++++++++ .../TransposeKernel.inc.metal | 2 +- .../paddle-mobile.xcodeproj/project.pbxproj | 16 ++++ .../Src/Operators/Base/OpCreator.swift | 8 +- .../Src/Operators/Base/Operator.swift | 12 ++- .../Src/Operators/Kernels/Relu6Kernel.swift | 49 +++++++++++ .../Src/Operators/Kernels/ScaleOpKernel.swift | 61 ++++++++++++++ .../paddle-mobile/Src/Operators/Relu6Op.swift | 56 +++++++++++++ .../paddle-mobile/Src/Operators/ScaleOp.swift | 57 +++++++++++++ .../paddle-mobile/Src/Program/PMOpDesc.swift | 4 + 12 files changed, 378 insertions(+), 4 deletions(-) create mode 100644 metal/paddle-mobile-metallib/paddle-mobile-metallib/ScaleKernel.metal create mode 100644 metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/Relu6Kernel.swift create mode 100644 metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ScaleOpKernel.swift create mode 100644 metal/paddle-mobile/paddle-mobile/Src/Operators/Relu6Op.swift create mode 100644 metal/paddle-mobile/paddle-mobile/Src/Operators/ScaleOp.swift diff --git a/metal/paddle-mobile-metallib/paddle-mobile-metallib.xcodeproj/project.pbxproj b/metal/paddle-mobile-metallib/paddle-mobile-metallib.xcodeproj/project.pbxproj index a8cdfbc293..29d94a7235 100644 --- a/metal/paddle-mobile-metallib/paddle-mobile-metallib.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile-metallib/paddle-mobile-metallib.xcodeproj/project.pbxproj @@ -9,6 +9,7 @@ /* Begin PBXBuildFile section */ 165F38D72276F4C00088E29F /* ConvAddReluMetal.metal in Sources */ = {isa = PBXBuildFile; fileRef = 165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */; }; 5CCC0CF6759710BAFE999DB7 /* Pods_paddle_mobile_metallib.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5D9D330A035906298947080B /* Pods_paddle_mobile_metallib.framework */; }; + A74CAFF0228D9B9B000BBFCA /* ScaleKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = A74CAFEF228D9B9B000BBFCA /* ScaleKernel.metal */; }; FCC15DE5221E69E100DC3CB2 /* ReluKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCC15DBC221E69DD00DC3CB2 /* ReluKernel.metal */; }; FCC15DE6221E69E100DC3CB2 /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCC15DBD221E69DD00DC3CB2 /* BoxCoder.metal */; }; FCC15DE7221E69E100DC3CB2 /* ConvAddBNReluKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCC15DBE221E69DD00DC3CB2 /* ConvAddBNReluKernel.metal */; }; @@ -56,6 +57,7 @@ 165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvAddReluMetal.metal; sourceTree = ""; }; 33511F4FF7FE78679BE12DC0 /* Pods-paddle-mobile-metallib.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-metallib.release.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-metallib/Pods-paddle-mobile-metallib.release.xcconfig"; sourceTree = ""; }; 5D9D330A035906298947080B /* Pods_paddle_mobile_metallib.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile_metallib.framework; sourceTree = BUILT_PRODUCTS_DIR; }; + A74CAFEF228D9B9B000BBFCA /* ScaleKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ScaleKernel.metal; sourceTree = ""; }; C6D31B9F9533810DBCA6B28D /* Pods-paddle-mobile-metallib.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-metallib.debug.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-metallib/Pods-paddle-mobile-metallib.debug.xcconfig"; sourceTree = ""; }; FCC15D60221E66DE00DC3CB2 /* paddle-mobile-metallib.metallib */ = {isa = PBXFileReference; explicitFileType = "archive.metal-library"; includeInIndex = 0; path = "paddle-mobile-metallib.metallib"; sourceTree = BUILT_PRODUCTS_DIR; }; FCC15DBC221E69DD00DC3CB2 /* ReluKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ReluKernel.metal; sourceTree = ""; }; @@ -192,6 +194,7 @@ FCC15DBF221E69DD00DC3CB2 /* Split.metal */, FCC15DC9221E69DE00DC3CB2 /* TransposeKernel.inc.metal */, FCC15DDA221E69E000DC3CB2 /* TransposeKernel.metal */, + A74CAFEF228D9B9B000BBFCA /* ScaleKernel.metal */, 165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */, ); path = "paddle-mobile-metallib"; @@ -284,6 +287,7 @@ FCC15DE8221E69E100DC3CB2 /* Split.metal in Sources */, FCC15DF2221E69E100DC3CB2 /* TransposeKernel.inc.metal in Sources */, FCC15DE7221E69E100DC3CB2 /* ConvAddBNReluKernel.metal in Sources */, + A74CAFF0228D9B9B000BBFCA /* ScaleKernel.metal in Sources */, FCC15E04221E69E100DC3CB2 /* ConvAddPrelu.inc.metal in Sources */, FCC15DF9221E69E100DC3CB2 /* Kernels.metal in Sources */, FCC15DF0221E69E100DC3CB2 /* PreluKernel.metal in Sources */, diff --git a/metal/paddle-mobile-metallib/paddle-mobile-metallib/ReluKernel.metal b/metal/paddle-mobile-metallib/paddle-mobile-metallib/ReluKernel.metal index 725222d75e..09882fd507 100644 --- a/metal/paddle-mobile-metallib/paddle-mobile-metallib/ReluKernel.metal +++ b/metal/paddle-mobile-metallib/paddle-mobile-metallib/ReluKernel.metal @@ -15,6 +15,9 @@ #include using namespace metal; +struct Relu6Param { + float threshold; +}; kernel void relu_half(texture2d_array inTexture [[texture(0)]], texture2d_array outTexture [[texture(1)]], @@ -39,3 +42,31 @@ kernel void relu(texture2d_array inTexture [[texture(0)]] const float4 relu = fmax((float4)input, 0.0); outTexture.write(float4(relu), gid.xy, gid.z); } + +kernel void relu6_half(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + constant Relu6Param &pm [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + if (gid.x >= outTexture.get_width() || + gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) return; + constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); + const half4 input = inTexture.read(gid.xy, gid.z); + const float threshold = pm.threshold; + const float4 relu = fmin(fmax((float4)input, 0.0), threshold); + outTexture.write(half4(relu), gid.xy, gid.z); +} + +kernel void relu6(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + constant Relu6Param &pm [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + if (gid.x >= outTexture.get_width() || + gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) return; + constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); + const float4 input = inTexture.read(gid.xy, gid.z); + const float threshold = pm.threshold; + const float4 relu = fmin(fmax((float4)input, 0.0), threshold); + outTexture.write(float4(relu), gid.xy, gid.z); +} diff --git a/metal/paddle-mobile-metallib/paddle-mobile-metallib/ScaleKernel.metal b/metal/paddle-mobile-metallib/paddle-mobile-metallib/ScaleKernel.metal new file mode 100644 index 0000000000..d494c815f1 --- /dev/null +++ b/metal/paddle-mobile-metallib/paddle-mobile-metallib/ScaleKernel.metal @@ -0,0 +1,82 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include "Common.metal" +using namespace metal; + +struct ScaleParam { + float scale; + float abias; +}; + +kernel void scale_before_bias_float(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + constant ScaleParam &pm [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + if (gid.x >= outTexture.get_width() || + gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) return; + constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); + const float4 input = inTexture.read(gid.xy, gid.z); + const float scale = pm.scale; + const float abias = pm.abias; + const float4 output = scale * input + abias; + outTexture.write(output, gid.xy, gid.z); +} + +kernel void scale_after_bias_float(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + constant ScaleParam &pm [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + if (gid.x >= outTexture.get_width() || + gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) return; + constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); + const float4 input = inTexture.read(gid.xy, gid.z); + const float scale = pm.scale; + const float abias = pm.abias; + const float4 output = scale * (input + abias); + outTexture.write(output, gid.xy, gid.z); +} + +kernel void scale_before_bias_half(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + constant ScaleParam &pm [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + if (gid.x >= outTexture.get_width() || + gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) return; + constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); + const half4 input = inTexture.read(gid.xy, gid.z); + const float scale = pm.scale; + const float abias = pm.abias; + const float4 output = scale * (float4)input + abias; + outTexture.write(half4(output), gid.xy, gid.z); +} + +kernel void scale_after_bias_half(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + constant ScaleParam &pm [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + if (gid.x >= outTexture.get_width() || + gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) return; + constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); + const half4 input = inTexture.read(gid.xy, gid.z); + const float scale = pm.scale; + const float abias = pm.abias; + const float4 output = scale * ((float4)input + abias); + outTexture.write(half4(output), gid.xy, gid.z); +} diff --git a/metal/paddle-mobile-metallib/paddle-mobile-metallib/TransposeKernel.inc.metal b/metal/paddle-mobile-metallib/paddle-mobile-metallib/TransposeKernel.inc.metal index d80361da46..8fa73b2011 100644 --- a/metal/paddle-mobile-metallib/paddle-mobile-metallib/TransposeKernel.inc.metal +++ b/metal/paddle-mobile-metallib/paddle-mobile-metallib/TransposeKernel.inc.metal @@ -31,7 +31,7 @@ kernel void FUNC(transpose, R, P)(texture2d_array inTexture [[t for (int n = 0; n < 4; n++) { oxyzn[3] = n; #if R == 4 - xyzn2abcd_4(pm.oC, oxyzn, iabcd); + xyzn2abcd_4(pm.oC, oxyzn, oabcd); #endif // R == 4 #if R == 3 xyzn2abcd_3(oxyzn, oabcd); diff --git a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj index 8f258aef45..3cf49082ef 100644 --- a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj @@ -19,6 +19,10 @@ 4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA932146661500D0F791 /* ShapeKernel.swift */; }; 4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA972146666500D0F791 /* FlattenOp.swift */; }; 4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */; }; + A73DC749227F1C7A001EB663 /* ScaleOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = A73DC748227F1C7A001EB663 /* ScaleOp.swift */; }; + A73DC74B227F1EDE001EB663 /* ScaleOpKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = A73DC74A227F1EDE001EB663 /* ScaleOpKernel.swift */; }; + A7F26FDA22842EF200365D47 /* Relu6Op.swift in Sources */ = {isa = PBXBuildFile; fileRef = A7F26FD922842EF200365D47 /* Relu6Op.swift */; }; + A7F26FDC2284301500365D47 /* Relu6Kernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = A7F26FDB2284301500365D47 /* Relu6Kernel.swift */; }; C28FE02F21BA68C00054EFAC /* Metal.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C28FE02C21BA68C00054EFAC /* Metal.framework */; settings = {ATTRIBUTES = (Weak, ); }; }; C28FE03021BA68C00054EFAC /* MetalPerformanceShaders.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C28FE02D21BA68C00054EFAC /* MetalPerformanceShaders.framework */; settings = {ATTRIBUTES = (Weak, ); }; }; C28FE03121BA68C00054EFAC /* MetalKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C28FE02E21BA68C00054EFAC /* MetalKit.framework */; settings = {ATTRIBUTES = (Weak, ); }; }; @@ -115,6 +119,10 @@ 4AA1EA932146661500D0F791 /* ShapeKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeKernel.swift; sourceTree = ""; }; 4AA1EA972146666500D0F791 /* FlattenOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenOp.swift; sourceTree = ""; }; 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenKernel.swift; sourceTree = ""; }; + A73DC748227F1C7A001EB663 /* ScaleOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ScaleOp.swift; sourceTree = ""; }; + A73DC74A227F1EDE001EB663 /* ScaleOpKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ScaleOpKernel.swift; sourceTree = ""; }; + A7F26FD922842EF200365D47 /* Relu6Op.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Relu6Op.swift; sourceTree = ""; }; + A7F26FDB2284301500365D47 /* Relu6Kernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Relu6Kernel.swift; sourceTree = ""; }; C28FE02C21BA68C00054EFAC /* Metal.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Metal.framework; path = System/Library/Frameworks/Metal.framework; sourceTree = SDKROOT; }; C28FE02D21BA68C00054EFAC /* MetalPerformanceShaders.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalPerformanceShaders.framework; path = System/Library/Frameworks/MetalPerformanceShaders.framework; sourceTree = SDKROOT; }; C28FE02E21BA68C00054EFAC /* MetalKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalKit.framework; path = System/Library/Frameworks/MetalKit.framework; sourceTree = SDKROOT; }; @@ -328,6 +336,8 @@ FCE3A1A82153DE5100C37CDE /* ConvAddAddPreluOp.swift */, FCE3A1AC2153E8BA00C37CDE /* ElementwiseAddPreluOp.swift */, 165F38D22276CDEA0088E29F /* ConvAddReluOp.swift */, + A73DC748227F1C7A001EB663 /* ScaleOp.swift */, + A7F26FD922842EF200365D47 /* Relu6Op.swift */, ); path = Operators; sourceTree = ""; @@ -383,6 +393,8 @@ FC2BFD4521DF685F00C262B2 /* Scale.swift */, FCB40E5821E0DCAB0075EC91 /* FetchKernel.swift */, 165F38D42276CE7D0088E29F /* ConvAddReluKernel.swift */, + A73DC74A227F1EDE001EB663 /* ScaleOpKernel.swift */, + A7F26FDB2284301500365D47 /* Relu6Kernel.swift */, ); path = Kernels; sourceTree = ""; @@ -530,6 +542,7 @@ files = ( FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */, FC039B9F20E11CB20081E9F8 /* Tensor.swift in Sources */, + A73DC74B227F1EDE001EB663 /* ScaleOpKernel.swift in Sources */, 4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */, FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */, FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */, @@ -594,6 +607,7 @@ FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */, FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */, FCD04E6A20F319EC0007374F /* SoftmaxOp.swift in Sources */, + A7F26FDA22842EF200365D47 /* Relu6Op.swift in Sources */, FCBCCC612122FBDF00D94F7E /* PriorBoxKernel.swift in Sources */, FCBCCC5F2122FB3B00D94F7E /* PriorBoxOp.swift in Sources */, FC9D038220E2312E000F735A /* FetchOp.swift in Sources */, @@ -607,6 +621,7 @@ FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */, FCE3A1AD2153E8BA00C37CDE /* ElementwiseAddPreluOp.swift in Sources */, FC039BC020E11CC20081E9F8 /* PMBlockDesc.swift in Sources */, + A7F26FDC2284301500365D47 /* Relu6Kernel.swift in Sources */, FCD04E6820F315020007374F /* PoolKernel.swift in Sources */, FC039BAD20E11CBC0081E9F8 /* ReluOp.swift in Sources */, FCBCCC572122F41300D94F7E /* DwConvBNReluOp.swift in Sources */, @@ -615,6 +630,7 @@ 4AA1EA88214662BD00D0F791 /* BilinearInterpKernel.swift in Sources */, FC2BFD4621DF685F00C262B2 /* Scale.swift in Sources */, FC039B9720E11C9A0081E9F8 /* Extensions.swift in Sources */, + A73DC749227F1C7A001EB663 /* ScaleOp.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Base/OpCreator.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Base/OpCreator.swift index bbd726cc0c..ffb7657a35 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Operators/Base/OpCreator.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Base/OpCreator.swift @@ -69,7 +69,13 @@ class OpCreator { gConvAddAddPreluType : ConvAddAddPreluOp

.creat, gElementwiseAddPreluType : ElementwiseAddPreluOp

.creat, gFusionConvAddType : ConvAddOp

.creat, - gConvAddReluType : ConvAddReluOp

.creat] + gConvAddReluType : ConvAddReluOp

.creat, + gReshape2Type : ReshapeOp

.creat, + gTranspose2Type : TransposeOp

.creat, + gScaleType : ScaleOp

.creat, + gRelu6Type : Relu6Op

.creat + ] + private init(){} } diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Base/Operator.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Base/Operator.swift index 85474cb5a9..a8052b245a 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Operators/Base/Operator.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Base/Operator.swift @@ -176,11 +176,15 @@ let gBilinearInterpType = "bilinear_interp" let gSplit = "split" let gShape = "shape" let gFlatten = "flatten" -let gConvAddReluType = "conv_add_relu" +let gConvAddReluType = "conv_add_relu" let gConvAddPreluType = "conv_add_prelu" let gConvAddAddPreluType = "conv_add_add_prelu" let gElementwiseAddPreluType = "elementwise_add_prelu" let gFusionConvAddType = "fusion_conv_add" +let gReshape2Type = "reshape2" +let gTranspose2Type = "transpose2" +let gScaleType = "scale" +let gRelu6Type = "relu6" let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]), @@ -211,5 +215,9 @@ let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Out gConvAddPreluType : (inputs: ["Input"], outputs: ["Out"]), gConvAddAddPreluType : (inputs: ["Input"], outputs: ["Out"]), gElementwiseAddPreluType : (inputs: ["X"], outputs: ["Out"]), - gFusionConvAddType : (inputs: ["Input"], outputs: ["Out"]) + gFusionConvAddType : (inputs: ["Input"], outputs: ["Out"]), + gReshape2Type : (inputs: ["X"], outputs: ["Out"]), + gTranspose2Type : (inputs: ["X"], outputs: ["Out"]), + gScaleType : (inputs: ["X"], outputs: ["Out"]), + gRelu6Type : (inputs: ["X"], outputs: ["Out"]), ] diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/Relu6Kernel.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/Relu6Kernel.swift new file mode 100644 index 0000000000..2152147e9f --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/Relu6Kernel.swift @@ -0,0 +1,49 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +import Foundation + +struct Relu6MetalParam { + let threshold: Float32 +} + +class Relu6Kernel: Kernel, Computable{ + var metalParam: Relu6MetalParam + func compute(commandBuffer: MTLCommandBuffer, param: Relu6Param

) throws { + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { + throw PaddleMobileError.predictError(message: " encode is nil") + } + encoder.setTexture(param.input.metalTexture, index: 0) + encoder.setTexture(param.output.metalTexture, index: 1) + encoder.setBytes(&metalParam, length: MemoryLayout.size, index: 0) + encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) + encoder.endEncoding() + } + + required init(device: MTLDevice, param: Relu6Param

, initContext: InitContext) throws { + do { + try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision) + } catch let error { + throw error + } + metalParam = Relu6MetalParam(threshold: param.threshold) + if GlobalConfig.shared.computePrecision == .Float32 { + super.init(device: device, inFunctionName: "relu6", initContext: initContext) + } else if GlobalConfig.shared.computePrecision == .Float16 { + super.init(device: device, inFunctionName: "relu6_half", initContext: initContext) + } else { + fatalError() + } + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ScaleOpKernel.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ScaleOpKernel.swift new file mode 100644 index 0000000000..c56bb844ab --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ScaleOpKernel.swift @@ -0,0 +1,61 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +import Foundation + +struct ScaleMetalParam { + let scale: Float32 + let abias: Float32 +} + +class ScaleOpKernel: Kernel, Computable{ + var metalParam: ScaleMetalParam + required init(device: MTLDevice, param: ScaleParam

, initContext: InitContext) throws { + do { + try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision) + } catch let error { + throw error + } + + metalParam = ScaleMetalParam(scale: param.scale, abias: param.bias) + + if GlobalConfig.shared.computePrecision == .Float32 { + if param.biasAfterScale { + super.init(device: device, inFunctionName: "scale_before_bias_float", initContext: initContext) + } else { + super.init(device: device, inFunctionName: "scale_after_bias_float", initContext: initContext) + } + } else if GlobalConfig.shared.computePrecision == .Float16 { + if param.biasAfterScale { + super.init(device: device, inFunctionName: "scale_before_bias_half", initContext: initContext) + } else { + super.init(device: device, inFunctionName: "scale_after_bias_half", initContext: initContext) + } + } else { + fatalError() + } + } + + func compute(commandBuffer: MTLCommandBuffer, param: ScaleParam

) throws { + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { + throw PaddleMobileError.predictError(message: " encoder is nil") + } + encoder.setTexture(param.input.metalTexture, index: 0) + encoder.setTexture(param.output.metalTexture, index: 1) + + encoder.setBytes(&metalParam, length: MemoryLayout.size, index: 0) + encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) + encoder.endEncoding() + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Relu6Op.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Relu6Op.swift new file mode 100644 index 0000000000..0eaeb9b503 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Relu6Op.swift @@ -0,0 +1,56 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + + +import Foundation + +class Relu6Param: OpParam { + required init(opDesc: PMOpDesc, inScope: Scope) throws { + do { + input = try Relu6Param.inputX(inputs: opDesc.inputs, from: inScope) + output = try Relu6Param.outputOut(outputs: opDesc.outputs, from: inScope) + threshold = try Relu6Param.getAttr(key: "threshold", attrs: opDesc.attrs) + } catch let error { + throw error + } + } + let input: Texture + var output: Texture + let threshold: Float32 +} + +class Relu6Op: Operator, Relu6Param

>, Runable, Creator, InferShaperable { + typealias OpType = Relu6Op

+ + func inferShape() { + para.output.dim = para.input.dim + } + + func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { + do { + try kernel.compute(commandBuffer: buffer, param: para) + } catch let error { + throw error + } + } + + func delogOutput() { + print(" \(type) output: ") + print(para.output.metalTexture) + print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray()) + } +} + + + diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/ScaleOp.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/ScaleOp.swift new file mode 100644 index 0000000000..31f8e4550a --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/ScaleOp.swift @@ -0,0 +1,57 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +import Foundation + +class ScaleParam: OpParam { + required init(opDesc: PMOpDesc, inScope: Scope) throws { + do { + input = try ScaleParam.inputX(inputs: opDesc.inputs, from: inScope) + output = try ScaleParam.outputOut(outputs: opDesc.outputs, from: inScope) + scale = try ScaleParam.getAttr(key: "scale", attrs: opDesc.attrs) + bias = try ScaleParam.getAttr(key: "bias", attrs: opDesc.attrs) + biasAfterScale = try ScaleParam.getAttr(key: "bias_after_scale", attrs: opDesc.attrs) + } catch let error { + throw error + } + } + + let input: Texture + var output: Texture + let scale: Float32 + let bias: Float32 + let biasAfterScale: Bool +} + +class ScaleOp: Operator, ScaleParam

>, Runable, Creator, InferShaperable{ + typealias OpType = ScaleOp

+ + func inferShape() { + para.output.dim = para.input.dim + } + + func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { + do { + try kernel.compute(commandBuffer: buffer, param: para) + } catch let error { + throw error + } + } + + func delogOutput() { + print(" \(type) output: ") + print(para.output.metalTexture) + print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray()) + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Src/Program/PMOpDesc.swift b/metal/paddle-mobile/paddle-mobile/Src/Program/PMOpDesc.swift index 51a9e6be2f..f64a2da1f2 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Program/PMOpDesc.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Program/PMOpDesc.swift @@ -33,6 +33,10 @@ class PMOpDesc { return map } + guard let _ = opInfos[protoOpDesc.type] else { + fatalError() + } + inputs = creator(protoOpDesc.inputsArray as! [OpDesc_Var]) { opInfos[protoOpDesc.type]?.inputs.contains($0) ?? false } -- GitLab