From 6ae74fd487337ccd9201346b6be9bc5e856691fd Mon Sep 17 00:00:00 2001 From: dolphin8 Date: Wed, 12 Sep 2018 13:34:01 +0800 Subject: [PATCH] concat --- .../paddle-mobile.xcodeproj/project.pbxproj | 12 +- .../Operators/Kernels/metal/Common.metal | 49 ++++++++ .../Operators/Kernels/metal/Concat.metal | 116 ------------------ .../Kernels/metal/ConcatKernel.metal | 57 +++++++++ .../Kernels/metal/ConcatKernel.metal.inc | 82 +++++++++++++ 5 files changed, 196 insertions(+), 120 deletions(-) delete mode 100644 metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Concat.metal create mode 100644 metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal create mode 100644 metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal.inc diff --git a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj index fe260157cd..c384195cb4 100644 --- a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj @@ -17,9 +17,10 @@ 4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA932146661500D0F791 /* ShapeKernel.swift */; }; 4AA1EA962146665A00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA952146665A00D0F791 /* FlattenKernel.swift */; }; 4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA972146666500D0F791 /* FlattenOp.swift */; }; + 4AA1EA9E2148D6F900D0F791 /* ConcatKernel.metal.inc in Headers */ = {isa = PBXBuildFile; fileRef = 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */; }; 4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; }; 4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; }; - 4AF928822135673D005B6C3A /* Concat.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* Concat.metal */; }; + 4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; }; 4AF9288421357BE3005B6C3A /* Elementwise.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9288321357BE3005B6C3A /* Elementwise.metal */; }; D3831F70E7E0B565B9AC22DA /* Pods_paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = DD2E06330A1E7129C918DB46 /* Pods_paddle_mobile.framework */; }; FC0226562138F33800F395E2 /* TransposeKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC0226552138F33800F395E2 /* TransposeKernel.metal */; }; @@ -124,9 +125,10 @@ 4AA1EA932146661500D0F791 /* ShapeKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeKernel.swift; sourceTree = ""; }; 4AA1EA952146665A00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenKernel.swift; sourceTree = ""; }; 4AA1EA972146666500D0F791 /* FlattenOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenOp.swift; sourceTree = ""; }; + 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.pascal; path = ConcatKernel.metal.inc; sourceTree = ""; }; 4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = ""; }; 4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = ""; }; - 4AF928812135673D005B6C3A /* Concat.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Concat.metal; sourceTree = ""; }; + 4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = ""; }; 4AF9288321357BE3005B6C3A /* Elementwise.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Elementwise.metal; sourceTree = ""; }; CDF58151D902A1CBAE56A0C2 /* Pods-paddle-mobile.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile.debug.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile/Pods-paddle-mobile.debug.xcconfig"; sourceTree = ""; }; DD2E06330A1E7129C918DB46 /* Pods_paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; }; @@ -434,7 +436,8 @@ isa = PBXGroup; children = ( FC27990D21341016000B6BAD /* BoxCoder.metal */, - 4AF928812135673D005B6C3A /* Concat.metal */, + 4AF928812135673D005B6C3A /* ConcatKernel.metal */, + 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */, 4AF9288321357BE3005B6C3A /* Elementwise.metal */, FC1B16B220EC9A4F00678B91 /* Kernels.metal */, FC4CB74820F0B954007C0C6D /* ConvKernel.metal */, @@ -468,6 +471,7 @@ FC4FD9792140E4980073E130 /* PaddleMobile.h in Headers */, FC292C85214257CB00CF622F /* CPUCompute.h in Headers */, FC292C5421421B2F00CF622F /* PaddleMobileGPU.h in Headers */, + 4AA1EA9E2148D6F900D0F791 /* ConcatKernel.metal.inc in Headers */, FC039B6F20E11C3C0081E9F8 /* paddle_mobile.h in Headers */, ); runOnlyForDeploymentPostprocessing = 0; @@ -621,7 +625,7 @@ FCBCCC6F2123097100D94F7E /* MulticlassNMSOp.swift in Sources */, FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */, FC292C872142624800CF622F /* Genet.swift in Sources */, - 4AF928822135673D005B6C3A /* Concat.metal in Sources */, + 4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */, FCBCCC632122FCC000D94F7E /* TransposeKernel.swift in Sources */, FCBCCC71212309A700D94F7E /* MulticlassNMSKernel.swift in Sources */, FCDC0FEB21099A1D00DC9EFB /* Tools.swift in Sources */, diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Common.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Common.metal index da703d163f..9858cf9c3c 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Common.metal +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Common.metal @@ -15,6 +15,55 @@ #include using namespace metal; + +inline void xyzn2abcd_1(int xyzn[4], int abcd[4]) { + abcd[0] = abcd[1] = abcd[2] = 1; + abcd[3] = xyzn[0] * 4 + xyzn[3]; +} +inline void xyzn2abcd_2(int xyzn[4], int abcd[4]) { + abcd[0] = abcd[1] = 1; + abcd[2] = xyzn[1]; + abcd[3] = xyzn[0] * 4 + xyzn[3]; +} +inline void xyzn2abcd_3(int xyzn[4], int abcd[4]) { + abcd[0] = 1; + abcd[3] = xyzn[0]; + abcd[2] = xyzn[1]; + abcd[1] = xyzn[2] * 4 + xyzn[3]; +} +inline void xyzn2abcd_4(int C, int xyzn[4], int abcd[4]) { + abcd[2] = xyzn[0]; + abcd[1] = xyzn[1]; + uint t = xyzn[2] * 4 + xyzn[3]; + abcd[0] = t / C; + abcd[3] = t % C; +} + +inline void abcd2xyzn_1(int abcd[4], int xyzn[4]) { + xyzn[1] = xyzn[2] = 1; + xyzn[0] = abcd[3] / 4; + xyzn[1] = abcd[3] % 4; +} +inline void abcd2xyzn_2(int abcd[4], int xyzn[4]) { + xyzn[2] = 1; + xyzn[1] = abcd[2]; + xyzn[0] = abcd[3] / 4; + xyzn[1] = abcd[3] % 4; +} +inline void abcd2xyzn_3(int abcd[4], int xyzn[4]) { + xyzn[0] = abcd[3]; + xyzn[1] = abcd[2]; + xyzn[2] = abcd[1] / 4; + xyzn[3] = abcd[1] % 4; +} +inline void abcd2xyzn_4(int C, int abcd[4], int xyzn[4]) { + xyzn[0] = abcd[2]; + xyzn[1] = abcd[1]; + uint t = abcd[0] * C + abcd[3]; + xyzn[2] = t / 4; + xyzn[3] = t % 4; +} + inline void xyzn2abcd(int C, int xyzn[4], int abcd[4]) { abcd[2] = xyzn[0]; abcd[1] = xyzn[1]; diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Concat.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Concat.metal deleted file mode 100644 index 92d80c315e..0000000000 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Concat.metal +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include -#include "Common.metal" - -using namespace metal; - -struct ConcatParam { - int32_t odim[4]; - int32_t axis; - int32_t offset; - int32_t trans[4]; - int32_t vdim[6]; -}; - -kernel void concat(texture2d_array in0 [[texture(0)]], - texture2d_array in1 [[texture(1)]], - texture2d_array in2 [[texture(2)]], - texture2d_array in3 [[texture(3)]], - texture2d_array in4 [[texture(4)]], - texture2d_array in5 [[texture(5)]], - texture2d_array inx [[texture(6)]], - texture2d_array out [[texture(7)]], - constant ConcatParam & pm [[buffer(0)]], - uint3 gid [[thread_position_in_grid]]) { - ConcatParam cp = pm; - int xyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, abcd[4], oxyzn[4]; - float4 r; - for (int i = 0; i < 4; i++) { - xyzn[3] = i; - xyzn2abcd(cp.odim[3], xyzn, abcd); - int k = abcd[cp.axis] - cp.offset; - int j = 0; - if (k < 0) { - r[i] = inx.read(gid.xy, gid.z)[i]; - } else { - for (; j < 6; j++) { - if (k < cp.vdim[j]) { - break; - } - k -= cp.vdim[j]; - } - int ta = cp.odim[cp.axis]; - abcd[cp.axis] = k; - cp.odim[cp.axis] = cp.vdim[j]; - abcd2xyzn(cp.odim[3], abcd, oxyzn); - cp.odim[cp.axis] = ta; - switch (j) { - case 0: r[i] = in0.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; - case 1: r[i] = in1.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; - case 2: r[i] = in2.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; - case 3: r[i] = in3.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; - case 4: r[i] = in4.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; - case 5: r[i] = in5.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; - } - } - } - out.write(r, gid.xy, gid.z); -} - -kernel void concat_half(texture2d_array in0 [[texture(0)]], - texture2d_array in1 [[texture(1)]], - texture2d_array in2 [[texture(2)]], - texture2d_array in3 [[texture(3)]], - texture2d_array in4 [[texture(4)]], - texture2d_array in5 [[texture(5)]], - texture2d_array inx [[texture(6)]], - texture2d_array out [[texture(7)]], - constant ConcatParam & pm [[buffer(0)]], - uint3 gid [[thread_position_in_grid]]) { - ConcatParam cp = pm; - int xyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, abcd[4], oxyzn[4]; - half4 r; - for (int i = 0; i < 4; i++) { - xyzn[3] = i; - xyzn2abcd(cp.odim[3], xyzn, abcd); - int k = abcd[cp.axis] - cp.offset; - int j = 0; - if (k < 0) { - r[i] = inx.read(gid.xy, gid.z)[i]; - } else { - for (; j < 6; j++) { - if (k < cp.vdim[j]) { - break; - } - k -= cp.vdim[j]; - } - int ta = cp.odim[cp.axis]; - abcd[cp.axis] = k; - cp.odim[cp.axis] = cp.vdim[j]; - abcd2xyzn(cp.odim[3], abcd, oxyzn); - cp.odim[cp.axis] = ta; - switch (j) { - case 0: r[i] = in0.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; - case 1: r[i] = in1.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; - case 2: r[i] = in2.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; - case 3: r[i] = in3.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; - case 4: r[i] = in4.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; - case 5: r[i] = in5.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; - } - } - } - out.write(r, gid.xy, gid.z); -} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal new file mode 100644 index 0000000000..b874a74663 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal @@ -0,0 +1,57 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include "Common.metal" + +using namespace metal; + +struct ConcatParam { + int32_t odim[4]; + int32_t axis; + int32_t offset; + int32_t trans[4]; + int32_t vdim[6]; +}; + + +#define P float +#define D 4 +#include "Concat.metal.inc" +#undef D +#define D 3 +#include "Concat.metal.inc" +#undef D +#define D 2 +#include "Concat.metal.inc" +#undef D +#define D 1 +#include "Concat.metal.inc" +#undef D +#undef P + +#define P half +#define D 4 +#include "Concat.metal.inc" +#undef D +#define D 3 +#include "Concat.metal.inc" +#undef D +#define D 2 +#include "Concat.metal.inc" +#undef D +#define D 1 +#include "Concat.metal.inc" +#undef D +#undef P diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal.inc b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal.inc new file mode 100644 index 0000000000..0e8a4e019a --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal.inc @@ -0,0 +1,82 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#ifndef D +#define D 4 +#endif + +#ifndef P +#define P float +#endif + +#define CONCAT2(a, b) a ## b +#define CONCAT2_(a, b) a ## _ ## b +#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c + +#define FUNC(f, d, p) CONCAT3_(f, d, p) +#define VECTOR(p, n) CONCAT2(p, n) +#define FUNC_D(f, d) CONCAT2_(f, d) + +kernel void FUNC(concat, D, P)(texture2d_array in0 [[texture(0)]], + texture2d_array in1 [[texture(1)]], + texture2d_array in2 [[texture(2)]], + texture2d_array in3 [[texture(3)]], + texture2d_array in4 [[texture(4)]], + texture2d_array in5 [[texture(5)]], + texture2d_array inx [[texture(6)]], + texture2d_array out [[texture(7)]], + constant ConcatParam & pm [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + ConcatParam cp = pm; + int xyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, abcd[4], oxyzn[4]; + VECTOR(P, 4) r; + for (int i = 0; i < 4; i++) { + xyzn[3] = i; +#if D == 4 + xyzn2abcd_4(cp.odim[3], xyzn, abcd); +#else + FUNC_D(xyzn2abcd, D)(xyzn, abcd); +#endif + int k = abcd[cp.axis] - cp.offset; + int j = 0; + if (k < 0) { + r[i] = inx.read(gid.xy, gid.z)[i]; + } else { + for (; j < 6; j++) { + if (k < cp.vdim[j]) { + break; + } + k -= cp.vdim[j]; + } + int ta = cp.odim[cp.axis]; + abcd[cp.axis] = k; + cp.odim[cp.axis] = cp.vdim[j]; +#if D == 4 + abcd2xyzn_4(cp.odim[3], abcd, oxyzn); +#else + FUNC_D(abcd2xyzn, D)(abcd, oxyzn); +#endif + cp.odim[cp.axis] = ta; + switch (j) { + case 0: r[i] = in0.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; + case 1: r[i] = in1.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; + case 2: r[i] = in2.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; + case 3: r[i] = in3.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; + case 4: r[i] = in4.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; + case 5: r[i] = in5.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; + } + } + } + out.write(r, gid.xy, gid.z); +} -- GitLab