diff --git a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj index b654a85bece77b8e96130f733652557c923dfa6f..a415bb9134655f543e563818df6b949f8049ff8c 100644 --- a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj @@ -21,6 +21,7 @@ 4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */; }; 4AA1EAA4214A295C00D0F791 /* Split.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA3214A295C00D0F791 /* Split.inc.metal */; }; 4AA1EAA6214B5F6800D0F791 /* Shape.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA5214B5F6800D0F791 /* Shape.metal */; }; + 4AA1EAA8214B7AFB00D0F791 /* BilinearInterp.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA7214B7AFB00D0F791 /* BilinearInterp.inc.metal */; }; 4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; }; 4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; }; 4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; }; @@ -134,6 +135,7 @@ 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenKernel.swift; sourceTree = ""; }; 4AA1EAA3214A295C00D0F791 /* Split.inc.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Split.inc.metal; sourceTree = ""; }; 4AA1EAA5214B5F6800D0F791 /* Shape.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Shape.metal; sourceTree = ""; }; + 4AA1EAA7214B7AFB00D0F791 /* BilinearInterp.inc.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BilinearInterp.inc.metal; sourceTree = ""; }; 4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = ""; }; 4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = ""; }; 4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = ""; }; @@ -456,6 +458,7 @@ 4AA1EA8F214664CD00D0F791 /* Split.metal */, 4AA1EAA3214A295C00D0F791 /* Split.inc.metal */, 4AA1EA892146631C00D0F791 /* BilinearInterp.metal */, + 4AA1EAA7214B7AFB00D0F791 /* BilinearInterp.inc.metal */, 4AF9287821341661005B6C3A /* Softmax.metal */, FCEB6849212F00DB00D2448E /* PreluKernel.metal */, FCDDC6C9212FDF6800E5EF74 /* BatchNormKernel.metal */, @@ -609,6 +612,7 @@ 4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */, FC33B0F02147659000714A93 /* MobileNet.swift in Sources */, FCEB684C212F093800D2448E /* PreluOp.swift in Sources */, + 4AA1EAA8214B7AFB00D0F791 /* BilinearInterp.inc.metal in Sources */, FCA67CD92138287B00BD58AA /* ConvBNReluKernel.metal in Sources */, FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */, FCEBC0F620F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift in Sources */, @@ -833,7 +837,7 @@ "$(PROJECT_DIR)/paddle-mobile/CPU", ); MACH_O_TYPE = mh_dylib; - MTL_LANGUAGE_REVISION = UseDeploymentTarget; + MTL_LANGUAGE_REVISION = Metal12; PRODUCT_BUNDLE_IDENTIFIER = "orange.paddle-mobile"; PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; SKIP_INSTALL = YES; @@ -869,7 +873,7 @@ "$(PROJECT_DIR)/paddle-mobile/CPU", ); MACH_O_TYPE = mh_dylib; - MTL_LANGUAGE_REVISION = UseDeploymentTarget; + MTL_LANGUAGE_REVISION = Metal12; PRODUCT_BUNDLE_IDENTIFIER = "orange.paddle-mobile"; PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; SKIP_INSTALL = YES; diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BilinearInterpKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BilinearInterpKernel.swift index ab6a44187f75fdee9484026ec859347b6c6166dc..478b1a5f807f4387ce04fde46e6d96c3cfdd06ec 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BilinearInterpKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BilinearInterpKernel.swift @@ -38,7 +38,7 @@ class BilinearInterpKernel: Kernel, Computable{ required init(device: MTLDevice, param: BilinearInterpParam

) { param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision) if computePrecision == .Float32 { - super.init(device: device, inFunctionName: "bilinear_interp") + super.init(device: device, inFunctionName: "bilinear_interp_float") } else if computePrecision == .Float16 { super.init(device: device, inFunctionName: "bilinear_interp_half") } else { diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift index e2d36049d6ee601857c8c6ae04862a21bf49b962..67e1cd9ab85c3c60d89846bab89ef10bbe513305 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift @@ -32,7 +32,7 @@ class SplitKernel: Kernel, Computable{ for i in 0...size, index: 0) + encoder.setBytes(&smp, length: MemoryLayout.size, index: 0) encoder.dispatch(computePipline: pipline, outTexture: param.input.metalTexture) encoder.endEncoding() } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BilinearInterp.inc.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BilinearInterp.inc.metal new file mode 100644 index 0000000000000000000000000000000000000000..cd6971bfda624a2c6b0bf9f4b51bf3e2a7c7195b --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BilinearInterp.inc.metal @@ -0,0 +1,34 @@ +#ifdef P + +#define CONCAT2(a, b) a ## b +#define CONCAT2_(a, b) a ## _ ## b + +#define FUNC(f, p) CONCAT2_(f, p) +#define VECTOR(p, n) CONCAT2(p, n) + +kernel void FUNC(bilinear_interp, P)(texture2d_array input [[texture(0)]], + texture2d_array output [[texture(1)]], + constant bilinear_interp_param & pm [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + VECTOR(P, 4) r; + if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) { + r = input.read(gid.xy, gid.z); + } else { + float w = gid.x * pm.ratio_w; + float h = gid.y * pm.ratio_h; + uint w0 = w, h0 = h; + uint w1 = w0 + 1, h1 = h0 + 1; + P w1lambda = w - w0, h1lambda = h - h0; + P w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda; + if (w1 >= input.get_width()) w1 = w0; + if (h1 >= input.get_height()) h1 = h0; + VECTOR(P, 4) r0 = input.read(uint2(w0, h0), gid.z); + VECTOR(P, 4) r1 = input.read(uint2(w1, h0), gid.z); + VECTOR(P, 4) r2 = input.read(uint2(w0, h1), gid.z); + VECTOR(P, 4) r3 = input.read(uint2(w1, h1), gid.z); + r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3); + } + output.write(r, gid.xy, gid.z); +} + +#endif diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BilinearInterp.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BilinearInterp.metal index 50c368e849b8e013dad7a4f374f5c4d3a1dd084c..c4eca3e1af7565b3dbef4646b80beb5a2725c714 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BilinearInterp.metal +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BilinearInterp.metal @@ -22,54 +22,10 @@ struct bilinear_interp_param { float ratio_w; }; -kernel void bilinear_interp(texture2d_array input [[texture(0)]], - texture2d_array output [[texture(1)]], - constant bilinear_interp_param & pm [[buffer(0)]], - uint3 gid [[thread_position_in_grid]]) { - float4 r; - if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) { - r = input.read(gid.xy, gid.z); - } else { - float w = gid.x * pm.ratio_w; - float h = gid.y * pm.ratio_h; - uint w0 = w, h0 = h; - uint w1 = w0 + 1, h1 = h0 + 1; - float w1lambda = w - w0, h1lambda = h - h0; - float w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda; - if (w1 >= input.get_width()) w1 = w0; - if (h1 >= input.get_height()) h1 = h0; - float4 r0 = input.read(uint2(w0, h0), gid.z); - float4 r1 = input.read(uint2(w1, h0), gid.z); - float4 r2 = input.read(uint2(w0, h1), gid.z); - float4 r3 = input.read(uint2(w1, h1), gid.z); - r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3); - } - output.write(r, gid.xy, gid.z); -} +#define P float +#include "BilinearInterp.inc.metal" +#undef P -//kernel void bilinear_interp_half(texture2d_array input [[texture(0)]], -// texture2d_array output [[texture(1)]], -// constant bilinear_interp_param & pm [[buffer(0)]], -// uint3 gid [[thread_position_in_grid]]) { -// -// half4 r; -// if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) { -// r = input.read(gid.xy, gid.z); -// } else { -// half w = gid.x * pm.ratio_w; -// half h = gid.y * pm.ratio_h; -// uint w0 = w, h0 = h; -// uint w1 = w0 + 1, h1 = h0 + 1; -// half w1lambda = w - w0, h1lambda = h - h0; -// half w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda; -// if (w1 >= input.get_width()) w1 = w0; -// if (h1 >= input.get_height()) h1 = h0; -// half4 r0 = input.read(uint2(w0, h0), gid.z); -// half4 r1 = input.read(uint2(w1, h0), gid.z); -// half4 r2 = input.read(uint2(w0, h1), gid.z); -// half4 r3 = input.read(uint2(w1, h1), gid.z); -// r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3); -// } -// output.write(r, gid.xy, gid.z); -// output.write(r, gid.xy, gid.z); -//} +#define P half +#include "BilinearInterp.inc.metal" +#undef P diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Split.inc.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Split.inc.metal index 4e1ab16cd7479f34fae578f7d914af061391fd12..bd45f635223f10cf0a0c4acd818c66996f30b2cf 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Split.inc.metal +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Split.inc.metal @@ -20,70 +20,89 @@ #define VV normal #endif +#if V == VY kernel void FUNC(split, R, N, VV, P)(texture2d_array input [[texture(0)]], texture2d_array out1 [[texture(1)]], texture2d_array out2 [[texture(2)]], #if N >= 3 texture2d_array out3 [[texture(3)]], -#endif +#endif // N >= 3 #if N >= 4 texture2d_array out4 [[texture(4)]], -#endif +#endif // N >= 4 constant SplitParam &sp [[buffer(0)]], uint3 gid [[thread_position_in_grid]]) { VECTOR(P, 4) r = input.read(gid.xy, gid.z); -#if V == VY int y = gid.y - sp.offset; if (y < sp.vdim[0]) { out1.write(r, gid.xy, gid.z); - } else { - y -= sp.vdim[0]; - if (y < sp.vdim[1]) { - out2.write(r, uint2(gid.x, y), gid.z); - } else { + return; + } + y -= sp.vdim[0]; + if (y < sp.vdim[1]) { + out2.write(r, uint2(gid.x, y), gid.z); + return; + } #if N >= 3 - y -= sp.vdim[1]; - if (y < sp.vdim[2]) { - out3.write(r, uint2(gid.x, y), gid.z); - } else { + y -= sp.vdim[1]; + if (y < sp.vdim[2]) { + out3.write(r, uint2(gid.x, y), gid.z); + return; + } +#endif // N >= 3 #if N >= 4 - y -= sp.vdim[2]; - if (y < sp.vdim[3]) { - out4.write(r, uint2(gid.x, y), gid.z); - } -#endif - } -#endif - } + y -= sp.vdim[2]; + if (y < sp.vdim[3]) { + out4.write(r, uint2(gid.x, y), gid.z); + return; } -#elif V == VX +#endif // N >= 4 +} +#endif // V == VY + + +#if V == VX +kernel void FUNC(split, R, N, VV, P)(texture2d_array input [[texture(0)]], + texture2d_array out1 [[texture(1)]], + texture2d_array out2 [[texture(2)]], +#if N >= 3 + texture2d_array out3 [[texture(3)]], +#endif // N >= 3 +#if N >= 4 + texture2d_array out4 [[texture(4)]], +#endif // N >= 4 + constant SplitParam &sp [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + VECTOR(P, 4) r = input.read(gid.xy, gid.z); int x = gid.x; if (x < sp.vdim[0]) { out1.write(r, gid.xy, gid.z); - } else { - x -= sp.vdim[0]; - if (x < sp.vdim[1]) { - out2.write(r, uint2(x, gid.y), gid.z); - } else { + return; + } + x -= sp.vdim[0]; + if (x < sp.vdim[1]) { + out2.write(r, uint2(x, gid.y), gid.z); + return; + } #if N >= 3 - x -= sp.vdim[1]; - if (x < sp.vdim[2]) { - out3.write(r, uint2(x, gid.y), gid.z); - } else { + x -= sp.vdim[1]; + if (x < sp.vdim[2]) { + out3.write(r, uint2(x, gid.y), gid.z); + return; + } +#endif // N >= 3 #if N >= 4 - x -= sp.vdim[2]; - if (x < sp.vdim[3]) { - out4.write(r, uint2(x, gid.y), gid.z); - } -#endif - } -#endif - } + x -= sp.vdim[2]; + if (x < sp.vdim[3]) { + out4.write(r, uint2(x, gid.y), gid.z); + return; } -#else -#endif +#endif // N >= 4 } +#endif // V == VX + + #undef VV #endif