diff --git a/metal/paddle-mobile-metallib/paddle-mobile-metallib/ConcatKernel.metal b/metal/paddle-mobile-metallib/paddle-mobile-metallib/ConcatKernel.metal index 497b3585c088f4b66192a17dea4b876f36ef2912..c4c9c7bbcff1a58635cd7d0cb0cd4e1485b0890a 100644 --- a/metal/paddle-mobile-metallib/paddle-mobile-metallib/ConcatKernel.metal +++ b/metal/paddle-mobile-metallib/paddle-mobile-metallib/ConcatKernel.metal @@ -67,8 +67,8 @@ struct ConcatParam { #undef R #undef V -// lens: (R=4, N=3, V=x) -#define V VX +// lens: (R=4, N=3, V=y) +#define V VY #define R 4 #define N 3 #define P float diff --git a/metal/paddle-mobile-metallib/paddle-mobile-metallib/NearestInterpKernel.metal b/metal/paddle-mobile-metallib/paddle-mobile-metallib/NearestInterpKernel.metal index 08d0d2dfa503d9f4b4342f06aac0456d1e8eb7c6..10f5d1f6f9fd060df9e5a601f809fc411c1c9a0a 100644 --- a/metal/paddle-mobile-metallib/paddle-mobile-metallib/NearestInterpKernel.metal +++ b/metal/paddle-mobile-metallib/paddle-mobile-metallib/NearestInterpKernel.metal @@ -28,8 +28,8 @@ kernel void nearest_interp(texture2d_array inTexture [[te gid.z >= outTexture.get_array_size()) return; constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); float scale = param.scale; - uint x = uint(round(float(gid.x) / scale)); - uint y = uint(round(float(gid.y) / scale)); + uint x = uint(floor(float(gid.x) / scale)); + uint y = uint(floor(float(gid.y) / scale)); const float4 input = inTexture.read(uint2(x, y), gid.z); outTexture.write(input, gid.xy, gid.z); } @@ -43,8 +43,8 @@ kernel void nearest_interp_half(texture2d_array inTexture gid.z >= outTexture.get_array_size()) return; constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); float scale = param.scale; - uint x = uint(round(float(gid.x) / scale)); - uint y = uint(round(float(gid.y) / scale)); + uint x = uint(floor(float(gid.x) / scale)); + uint y = uint(floor(float(gid.y) / scale)); const half4 input = inTexture.read(uint2(x, y), gid.z); outTexture.write(input, gid.xy, gid.z); } diff --git a/metal/paddle-mobile-metallib/paddle-mobile-metallib/SliceKernel.metal b/metal/paddle-mobile-metallib/paddle-mobile-metallib/SliceKernel.metal index acf61fefcd1e5f5ffc665bba3c2b9a9db42b2dd5..9cc260a33f6b76ec2255347b056ab272e281effd 100644 --- a/metal/paddle-mobile-metallib/paddle-mobile-metallib/SliceKernel.metal +++ b/metal/paddle-mobile-metallib/paddle-mobile-metallib/SliceKernel.metal @@ -24,6 +24,8 @@ struct MetalSliceParam { short end1; short end2; short end3; + int iC; + int oC; }; kernel void slice(texture2d_array inTexture [[texture(0)]], @@ -35,9 +37,14 @@ kernel void slice(texture2d_array inTexture [[texture(0)] gid.z >= outTexture.get_array_size()) return; constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); float4 output; - for (int i = 0; i < 4; i++) { - int input_c = gid.z * 4 + i + param.start1; - int input_z = input_c / 4; + for (int i = 0; i < 4; ++i) { + int tmp = gid.z * 4 + i; + int output_c = tmp % param.oC; + int output_n = tmp / param.oC; + int c = output_c + param.start1; + tmp = output_n * param.iC + c; + int input_z = tmp / 4; + int input_c = tmp % 4; const float4 input = inTexture.read(gid.xy, input_z); output[i] = input[input_c % 4]; } @@ -52,12 +59,17 @@ kernel void slice_half(texture2d_array inTexture [[texture gid.y >= outTexture.get_height() || gid.z >= outTexture.get_array_size()) return; constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); - float4 output; - for (int i = 0; i < 4; i++) { - int input_c = gid.z * 4 + i + param.start1; - int input_z = input_c / 4; - const float4 input = float4(inTexture.read(gid.xy, input_z)); + half4 output; + for (int i = 0; i < 4; ++i) { + int tmp = gid.z * 4 + i; + int output_c = tmp % param.oC; + int output_n = tmp / param.oC; + int c = output_c + param.start1; + tmp = output_n * param.iC + c; + int input_z = tmp / 4; + int input_c = tmp % 4; + const half4 input = inTexture.read(gid.xy, input_z); output[i] = input[input_c % 4]; } - outTexture.write(half4(output), gid.xy, gid.z); + outTexture.write(output, gid.xy, gid.z); } diff --git a/metal/paddle-mobile/paddle-mobile/Src/Common/MetalExtension.swift b/metal/paddle-mobile/paddle-mobile/Src/Common/MetalExtension.swift index 615ff5328230663e9a6ec9e9d54dceff3b9bf886..c09669137c9ea34eb1bf829493de70243a75ee0b 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Common/MetalExtension.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Common/MetalExtension.swift @@ -534,7 +534,7 @@ public extension MTLTexture { for c in 0..<4{ for h in 0..: Kernel, Computable{ required init(device: MTLDevice, param: ReshapeParam

, initContext: InitContext) throws { do { - try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision) + try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision) } catch let error { throw error } diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/SliceKernel.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/SliceKernel.swift index 565c974a86f14d7df3c73931c4cb212060e82bbf..bb2b7821db07d58d3e8ac22a22972f5e68430f7d 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/SliceKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/SliceKernel.swift @@ -23,6 +23,8 @@ public struct SliceMetalParam { let end1: Int16 let end2: Int16 let end3: Int16 + let iC: Int32 + let oC: Int32 } class SliceKernel: Kernel, Computable { @@ -40,7 +42,7 @@ class SliceKernel: Kernel, Computable { required init(device: MTLDevice, param: SliceParam

, initContext: InitContext) throws { do { - try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision) + try param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision) } catch let error { throw error } @@ -60,7 +62,11 @@ class SliceKernel: Kernel, Computable { let end1 = ranges[1][1] let end2 = ranges[2][1] let end3 = ranges[3][1] - metalParam = SliceMetalParam.init(start0: start0, start1: start1, start2: start2, start3: start3, end0: end0, end1: end1, end2: end2, end3: end3) + + let iC = Int32(param.input.tensorDim[1]) + let oC = Int32(param.output.tensorDim[1]) + + metalParam = SliceMetalParam.init(start0: start0, start1: start1, start2: start2, start3: start3, end0: end0, end1: end1, end2: end2, end3: end3, iC: iC, oC: oC) if GlobalConfig.shared.computePrecision == .Float32 { super.init(device: device, inFunctionName: "slice", initContext: initContext) } else if GlobalConfig.shared.computePrecision == .Float16 {