From 20f5ae65933e0085fc3c3cd6b123d52bd0b6aa2c Mon Sep 17 00:00:00 2001 From: Yanzhan Yang Date: Mon, 20 May 2019 20:01:21 +0800 Subject: [PATCH] 1.add flatten2, slice and nearest_interp op. 2.truncate tensor with more than 4 dimensions to 4. (#1634) --- .../project.pbxproj | 8 +++ .../paddle-mobile-metallib/Common.metal | 1 - .../paddle-mobile-metallib/ConcatKernel.metal | 43 +++++++++-- .../NearestInterpKernel.metal | 50 +++++++++++++ .../paddle-mobile-metallib/SliceKernel.metal | 63 ++++++++++++++++ .../paddle-mobile.xcodeproj/project.pbxproj | 16 +++++ .../Src/Operators/Base/OpCreator.swift | 3 + .../Src/Operators/Base/Operator.swift | 6 ++ .../Src/Operators/FlattenOp.swift | 50 +++++++++++-- .../Src/Operators/Kernels/FlattenKernel.swift | 56 ++++++++++++++- .../Kernels/NearestInterpKernel.swift | 49 +++++++++++++ .../Src/Operators/Kernels/SliceKernel.swift | 72 +++++++++++++++++++ .../Src/Operators/NearestInterpOp.swift | 63 ++++++++++++++++ .../Src/Operators/ReshapeOp.swift | 8 ++- .../paddle-mobile/Src/Operators/SliceOp.swift | 67 +++++++++++++++++ .../Src/Program/TensorDesc.swift | 13 ++++ 16 files changed, 554 insertions(+), 14 deletions(-) create mode 100644 metal/paddle-mobile-metallib/paddle-mobile-metallib/NearestInterpKernel.metal create mode 100644 metal/paddle-mobile-metallib/paddle-mobile-metallib/SliceKernel.metal create mode 100644 metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/NearestInterpKernel.swift create mode 100644 metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/SliceKernel.swift create mode 100644 metal/paddle-mobile/paddle-mobile/Src/Operators/NearestInterpOp.swift create mode 100644 metal/paddle-mobile/paddle-mobile/Src/Operators/SliceOp.swift diff --git a/metal/paddle-mobile-metallib/paddle-mobile-metallib.xcodeproj/project.pbxproj b/metal/paddle-mobile-metallib/paddle-mobile-metallib.xcodeproj/project.pbxproj index a7cd8fecba..1b097dfb99 100644 --- a/metal/paddle-mobile-metallib/paddle-mobile-metallib.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile-metallib/paddle-mobile-metallib.xcodeproj/project.pbxproj @@ -7,7 +7,9 @@ objects = { /* Begin PBXBuildFile section */ + 16324D862292C4930047277D /* NearestInterpKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 16324D852292C4930047277D /* NearestInterpKernel.metal */; }; 165F38D72276F4C00088E29F /* ConvAddReluMetal.metal in Sources */ = {isa = PBXBuildFile; fileRef = 165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */; }; + 16D3F3BB22929EAD0067C45D /* SliceKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 16D3F3BA22929EAD0067C45D /* SliceKernel.metal */; }; 16FBFB3E22925D040025B406 /* ActivationKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 16FBFB3D22925D040025B406 /* ActivationKernel.metal */; }; 5CCC0CF6759710BAFE999DB7 /* Pods_paddle_mobile_metallib.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5D9D330A035906298947080B /* Pods_paddle_mobile_metallib.framework */; }; A74CAFF0228D9B9B000BBFCA /* ScaleKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = A74CAFEF228D9B9B000BBFCA /* ScaleKernel.metal */; }; @@ -55,7 +57,9 @@ /* End PBXBuildFile section */ /* Begin PBXFileReference section */ + 16324D852292C4930047277D /* NearestInterpKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = NearestInterpKernel.metal; sourceTree = ""; }; 165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvAddReluMetal.metal; sourceTree = ""; }; + 16D3F3BA22929EAD0067C45D /* SliceKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = SliceKernel.metal; sourceTree = ""; }; 16FBFB3D22925D040025B406 /* ActivationKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ActivationKernel.metal; sourceTree = ""; }; 33511F4FF7FE78679BE12DC0 /* Pods-paddle-mobile-metallib.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-metallib.release.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-metallib/Pods-paddle-mobile-metallib.release.xcconfig"; sourceTree = ""; }; 5D9D330A035906298947080B /* Pods_paddle_mobile_metallib.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile_metallib.framework; sourceTree = BUILT_PRODUCTS_DIR; }; @@ -199,6 +203,8 @@ A74CAFEF228D9B9B000BBFCA /* ScaleKernel.metal */, 165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */, 16FBFB3D22925D040025B406 /* ActivationKernel.metal */, + 16D3F3BA22929EAD0067C45D /* SliceKernel.metal */, + 16324D852292C4930047277D /* NearestInterpKernel.metal */, ); path = "paddle-mobile-metallib"; sourceTree = ""; @@ -316,12 +322,14 @@ FCC15E0B221E69E100DC3CB2 /* FetchKernel.inc.metal in Sources */, FCC15DEE221E69E100DC3CB2 /* ConvTransposeKernel.metal in Sources */, FCC15DFC221E69E100DC3CB2 /* ConvAddPreluKernel.metal in Sources */, + 16D3F3BB22929EAD0067C45D /* SliceKernel.metal in Sources */, FCC15E06221E69E100DC3CB2 /* BoxCoder.inc.metal in Sources */, FCC15DF1221E69E100DC3CB2 /* BilinearInterp.inc.metal in Sources */, FCC15E08221E69E100DC3CB2 /* Split.inc.metal in Sources */, FCC15DF4221E69E100DC3CB2 /* ResizeBilinear.metal in Sources */, FCC15E05221E69E100DC3CB2 /* BatchNormKernel.metal in Sources */, 165F38D72276F4C00088E29F /* ConvAddReluMetal.metal in Sources */, + 16324D862292C4930047277D /* NearestInterpKernel.metal in Sources */, FCC15DE6221E69E100DC3CB2 /* BoxCoder.metal in Sources */, FCC15DF6221E69E100DC3CB2 /* PoolKernel.metal in Sources */, FCC15E09221E69E100DC3CB2 /* ConcatKernel.inc.metal in Sources */, diff --git a/metal/paddle-mobile-metallib/paddle-mobile-metallib/Common.metal b/metal/paddle-mobile-metallib/paddle-mobile-metallib/Common.metal index 099b8ca77c..a25e354d71 100644 --- a/metal/paddle-mobile-metallib/paddle-mobile-metallib/Common.metal +++ b/metal/paddle-mobile-metallib/paddle-mobile-metallib/Common.metal @@ -117,4 +117,3 @@ struct MetalConvParam { ushort dilationX; ushort dilationY; }; - diff --git a/metal/paddle-mobile-metallib/paddle-mobile-metallib/ConcatKernel.metal b/metal/paddle-mobile-metallib/paddle-mobile-metallib/ConcatKernel.metal index 8a0390e624..497b3585c0 100644 --- a/metal/paddle-mobile-metallib/paddle-mobile-metallib/ConcatKernel.metal +++ b/metal/paddle-mobile-metallib/paddle-mobile-metallib/ConcatKernel.metal @@ -67,6 +67,19 @@ struct ConcatParam { #undef R #undef V +// lens: (R=4, N=3, V=x) +#define V VX +#define R 4 +#define N 3 +#define P float +#include "ConcatKernel.inc.metal" +#undef P +#define P half +#include "ConcatKernel.inc.metal" +#undef P +#undef N +#undef R +#undef V // ssd-ar: (R=3, N=2, V=y) #define V VY @@ -96,6 +109,19 @@ struct ConcatParam { #undef R #undef V +// lens: (R=4, N=2, V=z) +#define V VZ +#define R 4 +#define N 2 +#define P float +#include "ConcatKernel.inc.metal" +#undef P +#define P half +#include "ConcatKernel.inc.metal" +#undef P +#undef N +#undef R +#undef V // ssd: (R=2, N=6, V=y) #define V VY @@ -138,6 +164,19 @@ struct ConcatParam { #undef R #undef V +// lens: (R=2, N=3, V=normal) +#define V VNORMAL +#define R 2 +#define N 3 +#define P float +#include "ConcatKernel.inc.metal" +#undef P +#define P half +#include "ConcatKernel.inc.metal" +#undef P +#undef N +#undef R +#undef V #define V VY #define R 2 @@ -165,7 +204,3 @@ struct ConcatParam { #undef N #undef R #undef V - - - - diff --git a/metal/paddle-mobile-metallib/paddle-mobile-metallib/NearestInterpKernel.metal b/metal/paddle-mobile-metallib/paddle-mobile-metallib/NearestInterpKernel.metal new file mode 100644 index 0000000000..08d0d2dfa5 --- /dev/null +++ b/metal/paddle-mobile-metallib/paddle-mobile-metallib/NearestInterpKernel.metal @@ -0,0 +1,50 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +using namespace metal; + +struct NearestInterpParam { + float scale; +}; + +kernel void nearest_interp(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + constant NearestInterpParam ¶m [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + if (gid.x >= outTexture.get_width() || + gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) return; + constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); + float scale = param.scale; + uint x = uint(round(float(gid.x) / scale)); + uint y = uint(round(float(gid.y) / scale)); + const float4 input = inTexture.read(uint2(x, y), gid.z); + outTexture.write(input, gid.xy, gid.z); +} + +kernel void nearest_interp_half(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + constant NearestInterpParam ¶m [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + if (gid.x >= outTexture.get_width() || + gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) return; + constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); + float scale = param.scale; + uint x = uint(round(float(gid.x) / scale)); + uint y = uint(round(float(gid.y) / scale)); + const half4 input = inTexture.read(uint2(x, y), gid.z); + outTexture.write(input, gid.xy, gid.z); +} diff --git a/metal/paddle-mobile-metallib/paddle-mobile-metallib/SliceKernel.metal b/metal/paddle-mobile-metallib/paddle-mobile-metallib/SliceKernel.metal new file mode 100644 index 0000000000..acf61fefcd --- /dev/null +++ b/metal/paddle-mobile-metallib/paddle-mobile-metallib/SliceKernel.metal @@ -0,0 +1,63 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +using namespace metal; + +struct MetalSliceParam { + short start0; + short start1; + short start2; + short start3; + short end0; + short end1; + short end2; + short end3; +}; + +kernel void slice(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + constant MetalSliceParam ¶m [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + if (gid.x >= outTexture.get_width() || + gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) return; + constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); + float4 output; + for (int i = 0; i < 4; i++) { + int input_c = gid.z * 4 + i + param.start1; + int input_z = input_c / 4; + const float4 input = inTexture.read(gid.xy, input_z); + output[i] = input[input_c % 4]; + } + outTexture.write(output, gid.xy, gid.z); +} + +kernel void slice_half(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + constant MetalSliceParam ¶m [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + if (gid.x >= outTexture.get_width() || + gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) return; + constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); + float4 output; + for (int i = 0; i < 4; i++) { + int input_c = gid.z * 4 + i + param.start1; + int input_z = input_c / 4; + const float4 input = float4(inTexture.read(gid.xy, input_z)); + output[i] = input[input_c % 4]; + } + outTexture.write(half4(output), gid.xy, gid.z); +} diff --git a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj index 1bad0ed855..3bb6743442 100644 --- a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj @@ -7,8 +7,12 @@ objects = { /* Begin PBXBuildFile section */ + 16324D842292ABDB0047277D /* NearestInterpKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16324D832292ABDB0047277D /* NearestInterpKernel.swift */; }; 165F38D32276CDEA0088E29F /* ConvAddReluOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 165F38D22276CDEA0088E29F /* ConvAddReluOp.swift */; }; 165F38D52276CE7D0088E29F /* ConvAddReluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 165F38D42276CE7D0088E29F /* ConvAddReluKernel.swift */; }; + 16D3F3B522929C390067C45D /* SliceOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16D3F3B422929C390067C45D /* SliceOp.swift */; }; + 16D3F3B722929C660067C45D /* NearestInterpOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16D3F3B622929C660067C45D /* NearestInterpOp.swift */; }; + 16D3F3B922929D070067C45D /* SliceKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16D3F3B822929D070067C45D /* SliceKernel.swift */; }; 16FBFB36229259C60025B406 /* ExpOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16FBFB35229259C60025B406 /* ExpOp.swift */; }; 16FBFB3822925B030025B406 /* ExpKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16FBFB3722925B030025B406 /* ExpKernel.swift */; }; 16FBFB3A22925C3E0025B406 /* SigmoidKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 16FBFB3922925C3E0025B406 /* SigmoidKernel.swift */; }; @@ -113,8 +117,12 @@ /* End PBXBuildFile section */ /* Begin PBXFileReference section */ + 16324D832292ABDB0047277D /* NearestInterpKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = NearestInterpKernel.swift; sourceTree = ""; }; 165F38D22276CDEA0088E29F /* ConvAddReluOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddReluOp.swift; sourceTree = ""; }; 165F38D42276CE7D0088E29F /* ConvAddReluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddReluKernel.swift; sourceTree = ""; }; + 16D3F3B422929C390067C45D /* SliceOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SliceOp.swift; sourceTree = ""; }; + 16D3F3B622929C660067C45D /* NearestInterpOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = NearestInterpOp.swift; sourceTree = ""; }; + 16D3F3B822929D070067C45D /* SliceKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SliceKernel.swift; sourceTree = ""; }; 16FBFB35229259C60025B406 /* ExpOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ExpOp.swift; sourceTree = ""; }; 16FBFB3722925B030025B406 /* ExpKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ExpKernel.swift; sourceTree = ""; }; 16FBFB3922925C3E0025B406 /* SigmoidKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SigmoidKernel.swift; sourceTree = ""; }; @@ -353,6 +361,8 @@ 16FBFB35229259C60025B406 /* ExpOp.swift */, 16FBFB3B22925C800025B406 /* SigmoidOp.swift */, 16FBFB3F229266FE0025B406 /* LeakyReluOp.swift */, + 16D3F3B422929C390067C45D /* SliceOp.swift */, + 16D3F3B622929C660067C45D /* NearestInterpOp.swift */, ); path = Operators; sourceTree = ""; @@ -413,6 +423,8 @@ 16FBFB3722925B030025B406 /* ExpKernel.swift */, 16FBFB3922925C3E0025B406 /* SigmoidKernel.swift */, 16FBFB412292684E0025B406 /* LeakyReluKernel.swift */, + 16D3F3B822929D070067C45D /* SliceKernel.swift */, + 16324D832292ABDB0047277D /* NearestInterpKernel.swift */, ); path = Kernels; sourceTree = ""; @@ -567,12 +579,14 @@ FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */, FCDE8A33212A917900F4A8F6 /* ConvTransposeOp.swift in Sources */, FCBCCC6B2123071700D94F7E /* BoxcoderOp.swift in Sources */, + 16D3F3B522929C390067C45D /* SliceOp.swift in Sources */, FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */, FCD04E7020F31B720007374F /* ReshapeKernel.swift in Sources */, FCD04E7220F343420007374F /* ConvAddOp.swift in Sources */, FC039BBB20E11CC20081E9F8 /* PMProgramDesc.swift in Sources */, FCE3A1AB2153DE8C00C37CDE /* ConvAddAddPreluKernel.swift in Sources */, FC9D037920E229E4000F735A /* OpParam.swift in Sources */, + 16324D842292ABDB0047277D /* NearestInterpKernel.swift in Sources */, FC3602CC2108819F00FACB58 /* PaddleMobileUnitTest.swift in Sources */, FCDDC6C6212F9FB800E5EF74 /* PreluKernel.swift in Sources */, FC9797CB21D6102D00F2FD90 /* ResizeBilinearKernel.swift in Sources */, @@ -605,6 +619,7 @@ 4AA1EA8E2146647F00D0F791 /* SplitKernel.swift in Sources */, FCD04E7420F3437E0007374F /* ConvAddKernel.swift in Sources */, FC1CF3F721D4B4C400F7392E /* Runner.swift in Sources */, + 16D3F3B722929C660067C45D /* NearestInterpOp.swift in Sources */, FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */, FCD04E6620F314C50007374F /* PoolOp.swift in Sources */, FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */, @@ -625,6 +640,7 @@ FC803BC1214CB77A0094B8E5 /* ConvAddPreluKernel.swift in Sources */, FCBCCC5D2122F8A100D94F7E /* DepthwiseConvOp.swift in Sources */, FCE3A1AF2153E8EE00C37CDE /* ElementwiseAddPreluKernel.swift in Sources */, + 16D3F3B922929D070067C45D /* SliceKernel.swift in Sources */, FCE9D7B7214F869000B520C3 /* Net.swift in Sources */, FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */, FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */, diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Base/OpCreator.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Base/OpCreator.swift index 79943e82a3..405596c7d9 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Operators/Base/OpCreator.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Base/OpCreator.swift @@ -77,6 +77,9 @@ class OpCreator { gExpType : ExpOp

.creat, gSigmoidType : SigmoidOp

.creat, gLeakyReluType : LeakyReluOp

.creat, + gFlatten2Type : Flatten2Op

.creat, + gSliceType : SliceOp

.creat, + gNearestInterpType : NearestInterpOp

.creat, ] diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Base/Operator.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Base/Operator.swift index ff882d6cbf..eef5d857c7 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Operators/Base/Operator.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Base/Operator.swift @@ -188,6 +188,9 @@ let gRelu6Type = "relu6" let gExpType = "exp" let gSigmoidType = "sigmoid" let gLeakyReluType = "leaky_relu" +let gFlatten2Type = "flatten2" +let gSliceType = "slice" +let gNearestInterpType = "nearest_interp" let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]), @@ -226,4 +229,7 @@ let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Out gExpType : (inputs: ["X"], outputs: ["Out"]), gSigmoidType : (inputs: ["X"], outputs: ["Out"]), gLeakyReluType : (inputs: ["X"], outputs: ["Out"]), + gFlatten2Type : (inputs: ["X"], outputs: ["Out"]), + gSliceType : (inputs: ["Input"], outputs: ["Out"]), + gNearestInterpType : (inputs: ["X"], outputs: ["Out"]), ] diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/FlattenOp.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/FlattenOp.swift index f5d1004948..951e572a88 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Operators/FlattenOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/FlattenOp.swift @@ -20,14 +20,14 @@ class FlattenParam: OpParam { do { input = try FlattenParam.inputX(inputs: opDesc.inputs, from: inScope) output = try FlattenParam.outputOut(outputs: opDesc.outputs, from: inScope) - axis = try FlattenParam.getAttr(key: "axis", attrs: opDesc.attrs) +// axis = try FlattenParam.getAttr(key: "axis", attrs: opDesc.attrs) } catch let error { throw error } } - let input: Texture + var input: Texture var output: Texture - let axis: Int + var axis: Int = 0 } @@ -56,8 +56,44 @@ class FlattenOp: Operator, FlattenParam

: OpParam { + required init(opDesc: PMOpDesc, inScope: Scope) throws { + do { + input = try Flatten2Param.inputX(inputs: opDesc.inputs, from: inScope) + output = try Flatten2Param.outputOut(outputs: opDesc.outputs, from: inScope) + + let inDims = input.dim + guard inDims.cout() == 4 else { + fatalError("flatten2 can't handle dims not equal to 4") + } + let outDim = [inDims[0] * inDims[1], inDims[2] * inDims[3]] + output.dim = Dim.init(inDim: outDim) + } catch let error { + throw error + } + } + var input: Texture + var output: Texture +} - - - - +class Flatten2Op: Operator, Flatten2Param

>, Runable, Creator, InferShaperable { + typealias OpType = Flatten2Op

+ + func inferShape() { + } + + func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { + do { + try kernel.compute(commandBuffer: buffer, param: para) + } catch let error { + throw error + } + } + + func delogOutput() { + print(" \(type) output: ") + let device = para.output.metalTexture!.device + let outputArray: [Float32] = device.texture2tensor(texture: para.output.metalTexture, dim: para.output.tensorDim.dims, transpose: para.output.transpose) + print(outputArray.strideArray()) + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/FlattenKernel.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/FlattenKernel.swift index da115361ba..ec2b22709e 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/FlattenKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/FlattenKernel.swift @@ -22,7 +22,7 @@ struct FlattenMetalParam { } -class FlattenKernel: Kernel, Computable{ +class FlattenKernel: Kernel, Computable { var metalParam: FlattenMetalParam @@ -75,3 +75,57 @@ class FlattenKernel: Kernel, Computable{ encoder.endEncoding() } } + +class Flatten2Kernel: Kernel, Computable { + + var metalParam: FlattenMetalParam + + required init(device: MTLDevice, param: Flatten2Param

, initContext: InitContext) throws { + + do { + try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision) + } catch let error { + throw error + } + + var id: [Int32] = [1, 1, 1, 1] + for i in 0..) throws { + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { + throw PaddleMobileError.predictError(message: " encoder is nil") + } + + encoder.setTexture(param.input.metalTexture, index: 0) + encoder.setTexture(param.output.metalTexture, index: 1) + + encoder.setBytes(&metalParam, length: MemoryLayout.size, index: 0) + encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) + encoder.endEncoding() + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/NearestInterpKernel.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/NearestInterpKernel.swift new file mode 100644 index 0000000000..d7ca978576 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/NearestInterpKernel.swift @@ -0,0 +1,49 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +import Foundation + +struct NearestInterpMetalParam { + let scale: Float32 +} + +class NearestInterpKernel: Kernel, Computable { + var metalParam: NearestInterpMetalParam + func compute(commandBuffer: MTLCommandBuffer, param: NearestInterpParam

) throws { + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { + throw PaddleMobileError.predictError(message: " encode is nil") + } + encoder.setTexture(param.input.metalTexture, index: 0) + encoder.setTexture(param.output.metalTexture, index: 1) + encoder.setBytes(&metalParam, length: MemoryLayout.size, index: 0) + encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) + encoder.endEncoding() + } + + required init(device: MTLDevice, param: NearestInterpParam

, initContext: InitContext) throws { + do { + try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision) + } catch let error { + throw error + } + metalParam = NearestInterpMetalParam(scale: param.scale) + if GlobalConfig.shared.computePrecision == .Float32 { + super.init(device: device, inFunctionName: "nearest_interp", initContext: initContext) + } else if GlobalConfig.shared.computePrecision == .Float16 { + super.init(device: device, inFunctionName: "nearest_interp_half", initContext: initContext) + } else { + fatalError() + } + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/SliceKernel.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/SliceKernel.swift new file mode 100644 index 0000000000..565c974a86 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/SliceKernel.swift @@ -0,0 +1,72 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +import Foundation + +public struct SliceMetalParam { + let start0: Int16 + let start1: Int16 + let start2: Int16 + let start3: Int16 + let end0: Int16 + let end1: Int16 + let end2: Int16 + let end3: Int16 +} + +class SliceKernel: Kernel, Computable { + var metalParam: SliceMetalParam + func compute(commandBuffer: MTLCommandBuffer, param: SliceParam

) throws { + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { + throw PaddleMobileError.predictError(message: " encode is nil") + } + encoder.setTexture(param.input.metalTexture, index: 0) + encoder.setTexture(param.output.metalTexture, index: 1) + encoder.setBytes(&metalParam, length: MemoryLayout.size, index: 0) + encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) + encoder.endEncoding() + } + + required init(device: MTLDevice, param: SliceParam

, initContext: InitContext) throws { + do { + try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision) + } catch let error { + throw error + } + var ranges = [[Int16]]() + for i in 0..<4 { + if let range = param.ranges[i] { + ranges.append(range) + } else { + ranges.append([0, Int16(param.input.tensorDim[i])]) + } + } + let start0 = ranges[0][0] + let start1 = ranges[1][0] + let start2 = ranges[2][0] + let start3 = ranges[3][0] + let end0 = ranges[0][1] + let end1 = ranges[1][1] + let end2 = ranges[2][1] + let end3 = ranges[3][1] + metalParam = SliceMetalParam.init(start0: start0, start1: start1, start2: start2, start3: start3, end0: end0, end1: end1, end2: end2, end3: end3) + if GlobalConfig.shared.computePrecision == .Float32 { + super.init(device: device, inFunctionName: "slice", initContext: initContext) + } else if GlobalConfig.shared.computePrecision == .Float16 { + super.init(device: device, inFunctionName: "slice_half", initContext: initContext) + } else { + fatalError() + } + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/NearestInterpOp.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/NearestInterpOp.swift new file mode 100644 index 0000000000..48259c7fc1 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/NearestInterpOp.swift @@ -0,0 +1,63 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +import Foundation + +class NearestInterpParam: OpParam { + required init(opDesc: PMOpDesc, inScope: Scope) throws { + do { + input = try NearestInterpParam.inputX(inputs: opDesc.inputs, from: inScope) + output = try NearestInterpParam.outputOut(outputs: opDesc.outputs, from: inScope) + let inputDim = input.tensorDim + let outputDim = output.tensorDim + guard inputDim.cout() == 4 && outputDim.cout() == 4 && inputDim[0] == outputDim[0] && inputDim[1] == outputDim[1] else { + fatalError("nearest interp only support scale along width and height") + } + let scaleX = Float32(outputDim[2]) / Float32(inputDim[2]) + let scaleY = Float32(outputDim[3]) / Float32(inputDim[3]) + guard abs(scaleX - scaleY) <= 0.00001 else { + fatalError("nearest interp only support same scale factor") + } + scale = scaleX + print("ok") + } catch let error { + throw error + } + } + var input: Texture + var output: Texture + var scale: Float32 +} + +class NearestInterpOp: Operator, NearestInterpParam

>, Runable, Creator, InferShaperable { + typealias OpType = NearestInterpOp

+ + func inferShape() { + } + + func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { + do { + try kernel.compute(commandBuffer: buffer, param: para) + } catch let error { + throw error + } + } + + func delogOutput() { + print(" \(type) output: ") + let device = para.output.metalTexture!.device + let outputArray: [Float32] = device.texture2tensor(texture: para.output.metalTexture, dim: para.output.tensorDim.dims, transpose: para.output.transpose) + print(outputArray.strideArray()) + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/ReshapeOp.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/ReshapeOp.swift index acff1c95ea..6cb4934f06 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Operators/ReshapeOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/ReshapeOp.swift @@ -23,6 +23,12 @@ class ReshapeParam: OpParam { output = try ReshapeParam.outputOut(outputs: opDesc.outputs, from: inScope) shape = try ReshapeParam.getAttr(key: "shape", attrs: opDesc.attrs) + if shape.count > 4 { + if shape[0] == -1 { + shape.removeFirst() + } + } + var s: [Int] = shape.map { Int($0) } var di = -1 @@ -49,7 +55,7 @@ class ReshapeParam: OpParam { } } let input: Texture - let shape: [Int32] + var shape: [Int32] var output: Texture } diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/SliceOp.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/SliceOp.swift new file mode 100644 index 0000000000..411eab936d --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/SliceOp.swift @@ -0,0 +1,67 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +import Foundation + +class SliceParam: OpParam { + //typealias ParamPrecisionType = P + required init(opDesc: PMOpDesc, inScope: Scope) throws { + do { + input = try SliceParam.input(inputs: opDesc.inputs, from: inScope) + output = try SliceParam.outputOut(outputs: opDesc.outputs, from: inScope) + starts = try SliceParam.getAttr(key: "starts", attrs: opDesc.attrs) + ends = try SliceParam.getAttr(key: "ends", attrs: opDesc.attrs) + for i in 0..: Operator, SliceParam

>, Runable, Creator, InferShaperable { + typealias OpType = SliceOp

+ + func inferShape() { + } + + func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { + do { + try kernel.compute(commandBuffer: buffer, param: para) + } catch let error { + throw error + } + } + + func delogOutput() { + print("\(type) output : ") + print(para.output.toTensor().strideArray()) + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Src/Program/TensorDesc.swift b/metal/paddle-mobile/paddle-mobile/Src/Program/TensorDesc.swift index 7565fffc99..badca3dbac 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Program/TensorDesc.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Program/TensorDesc.swift @@ -62,6 +62,19 @@ class TensorDesc { let dim = Int(protoTensorDesc.dimsArray.value(at: i)) > 0 ?Int(protoTensorDesc.dimsArray.value(at: i)) :abs(Int(protoTensorDesc.dimsArray.value(at: i))) dimsArray.append(dim) } + + if dimsCount > 4 { + let headDims = Int(dimsCount - 4) + for i in 0..