diff --git a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj index 5eb0b576ceb9a74c92385bc3b3ce70e88ea0abea..b654a85bece77b8e96130f733652557c923dfa6f 100644 --- a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj @@ -20,6 +20,7 @@ 4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */; }; 4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */; }; 4AA1EAA4214A295C00D0F791 /* Split.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA3214A295C00D0F791 /* Split.inc.metal */; }; + 4AA1EAA6214B5F6800D0F791 /* Shape.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA5214B5F6800D0F791 /* Shape.metal */; }; 4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; }; 4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; }; 4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; }; @@ -132,6 +133,7 @@ 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ReshapeKernel.inc.metal; sourceTree = ""; }; 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenKernel.swift; sourceTree = ""; }; 4AA1EAA3214A295C00D0F791 /* Split.inc.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Split.inc.metal; sourceTree = ""; }; + 4AA1EAA5214B5F6800D0F791 /* Shape.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Shape.metal; sourceTree = ""; }; 4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = ""; }; 4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = ""; }; 4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = ""; }; @@ -170,7 +172,6 @@ FC0E2DBD20EE460D009C1FAC /* BatchNormKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BatchNormKernel.swift; sourceTree = ""; }; FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ElementwiseAddKernel.swift; sourceTree = ""; }; FC1B16B220EC9A4F00678B91 /* Kernels.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = Kernels.metal; sourceTree = ""; }; - FC27990D21341016000B6BAD /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = ""; }; FC292C5321421B2E00CF622F /* PaddleMobileGPU.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = PaddleMobileGPU.h; sourceTree = ""; }; FC292C5521421B4600CF622F /* PaddleMobileGPU.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = PaddleMobileGPU.m; sourceTree = ""; }; FC292C7C214255BC00CF622F /* CPUCompute.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = CPUCompute.mm; sourceTree = ""; }; @@ -445,13 +446,13 @@ FCEB6837212F00B100D2448E /* metal */ = { isa = PBXGroup; children = ( - FC27990D21341016000B6BAD /* BoxCoder.metal */, 4AF928812135673D005B6C3A /* ConcatKernel.metal */, 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */, 4AF9288321357BE3005B6C3A /* Elementwise.metal */, FC1B16B220EC9A4F00678B91 /* Kernels.metal */, FC4CB74820F0B954007C0C6D /* ConvKernel.metal */, 4AF928762133F1DB005B6C3A /* BoxCoder.metal */, + 4AA1EAA5214B5F6800D0F791 /* Shape.metal */, 4AA1EA8F214664CD00D0F791 /* Split.metal */, 4AA1EAA3214A295C00D0F791 /* Split.inc.metal */, 4AA1EA892146631C00D0F791 /* BilinearInterp.metal */, @@ -627,6 +628,7 @@ 4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */, 4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */, FCBCCC652122FCD700D94F7E /* TransposeOp.swift in Sources */, + 4AA1EAA6214B5F6800D0F791 /* Shape.metal in Sources */, FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */, FC039B9820E11C9A0081E9F8 /* Errors.swift in Sources */, FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */, diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ShapeKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ShapeKernel.swift index 82c8dc4d92c31b8f809bc17ce2ea50cca8291d0c..f64d71ff015f47e889728ce502470724a1d2cade 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ShapeKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ShapeKernel.swift @@ -30,9 +30,9 @@ class ShapeKernel: Kernel, Computable{ required init(device: MTLDevice, param: ShapeParam

) { param.output.initTexture(device: device, computePrecision: computePrecision) if computePrecision == .Float32 { - super.init(device: device, inFunctionName: "split") + super.init(device: device, inFunctionName: "shape") } else if computePrecision == .Float16 { - super.init(device: device, inFunctionName: "split_half") + super.init(device: device, inFunctionName: "shape_half") } else { fatalError() } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift index c9944752be0fe1d878d9dbe173e635546176ddcc..e2d36049d6ee601857c8c6ae04862a21bf49b962 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift @@ -82,7 +82,7 @@ class SplitKernel: Kernel, Computable{ fatalError("split unsupported") } if computePrecision == .Float32 { - super.init(device: device, inFunctionName: "split_\(rank)_\(num)_\(v)") + super.init(device: device, inFunctionName: "split_\(rank)_\(num)_\(v)_float") } else if computePrecision == .Float16 { super.init(device: device, inFunctionName: "split_\(rank)_\(num)_\(v)_half") } else { diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Shape.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Shape.metal new file mode 100644 index 0000000000000000000000000000000000000000..b50d5547193ccc9a1bef1b3ed6bbd1b7a64c3527 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Shape.metal @@ -0,0 +1,21 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +using namespace metal; + +kernel void shape() { +} +kernel void shape_half() { +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Split.inc.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Split.inc.metal index 7532f35cbbb17efa33e209b0d126b384d86bfc57..4e1ab16cd7479f34fae578f7d914af061391fd12 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Split.inc.metal +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Split.inc.metal @@ -10,7 +10,17 @@ #define VECTOR(p, n) CONCAT2(p, n) #define FUNC_R(f, r) CONCAT2_(f, r) -kernel void FUNC(split, R, N, V, P)(texture2d_array input [[texture(0)]], +#if V == VX +#define VV x +#elif V == VY +#define VV y +#elif V == VZ +#define VV z +#else +#define VV normal +#endif + +kernel void FUNC(split, R, N, VV, P)(texture2d_array input [[texture(0)]], texture2d_array out1 [[texture(1)]], texture2d_array out2 [[texture(2)]], #if N >= 3 @@ -23,7 +33,7 @@ kernel void FUNC(split, R, N, V, P)(texture2d_array input [[tex uint3 gid [[thread_position_in_grid]]) { VECTOR(P, 4) r = input.read(gid.xy, gid.z); -#if V == y +#if V == VY int y = gid.y - sp.offset; if (y < sp.vdim[0]) { out1.write(r, gid.xy, gid.z); @@ -47,7 +57,7 @@ kernel void FUNC(split, R, N, V, P)(texture2d_array input [[tex #endif } } -#elif V == x +#elif V == VX int x = gid.x; if (x < sp.vdim[0]) { out1.write(r, gid.xy, gid.z); @@ -75,4 +85,5 @@ kernel void FUNC(split, R, N, V, P)(texture2d_array input [[tex #endif } +#undef VV #endif diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Split.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Split.metal index ca51c3c49b5867fdf08a58d54802c0ba157663a2..914ab6d925234947db38df739da4af2dd076083e 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Split.metal +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Split.metal @@ -25,48 +25,26 @@ struct SplitParam { int32_t vdim[4]; }; +#define VNORMAL 1 +#define VX 2 +#define VY 3 +#define VZ 4 + // only support split_{2, 3, 4}_{2, 3, 4}_y_{float, half} // only support split_{3, 4}_{2, 3, 4}_x_{float, half} -#define V y -// for R in 2..4 -#define R 3 - -// for N in 2..4 -#define N 2 - -#define P float -#include "Split.inc.metal" -#undef P -#define P half -#include "Split.inc.metal" -#undef P - -#undef N -// end for N - -#undef R -// end for R -#undef V - -#define V x -// for R in 3..4 -#define R 3 - -// for N in 2..4 -#define N 2 - -#define P float -#include "Split.inc.metal" -#undef P -#define P half -#include "Split.inc.metal" -#undef P - -#undef N -// end for N -#undef R -// end for R +//// ssd-ar: (R=3, N=2, V=y) +#define V VY + #define R 3 + #define N 2 + #define P float + #include "Split.inc.metal" + #undef P + #define P half + #include "Split.inc.metal" + #undef P + #undef N + #undef R #undef V