diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ConvAddKernel.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ConvAddKernel.swift index cb6944980a7ac10808748acd6d698d26f5d205ac..32c61e9dcdb9eacc273d4712ca58f1ecdab36348 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ConvAddKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ConvAddKernel.swift @@ -114,6 +114,10 @@ class ConvAddKernel: Kernel, Computable { if type(of: self).isWinoGrad(functionName: functionName) { shouldUseMPS = false } + let isDepthWise = param.filter.tensorDim[1] == 1 && param.filter.tensorDim[0] == param.input.tensorDim[1] + if !isDepthWise && param.groups > 1 { + shouldUseMPS = false + } if shouldUseMPS { super.init(device: device, inFunctionName: nil, initContext: initContext) setupWithMPS(device: device, param: param) diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ReshapeKernel.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ReshapeKernel.swift index 11864698c36c0104b115862f1814fb8703c6dd2f..03cc4cad3c0926eac1a72582ab3f80f81b7cbc56 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ReshapeKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ReshapeKernel.swift @@ -34,7 +34,7 @@ class ReshapeKernel: Kernel, Computable{ required init(device: MTLDevice, param: ReshapeParam

, initContext: InitContext) throws { do { - try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision) + try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision) } catch let error { throw error } diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/SliceKernel.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/SliceKernel.swift index bb2b7821db07d58d3e8ac22a22972f5e68430f7d..2f241722126232e6048007b7f18005ba29e3ac4b 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/SliceKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/SliceKernel.swift @@ -29,11 +29,24 @@ public struct SliceMetalParam { class SliceKernel: Kernel, Computable { var metalParam: SliceMetalParam + var device: MTLDevice? + var initContext: InitContext? + func compute(commandBuffer: MTLCommandBuffer, param: SliceParam

) throws { + let expectedTranspose = [0, 2, 3, 1] + var input = param.input + if param.input.transpose != expectedTranspose { + if let device = device, let initContext = initContext, let transposedInput = encodeTransposeInput(input: param.input, toTranspose: expectedTranspose, commandBuffer: commandBuffer, device: device, initContext: initContext) { + input = transposedInput + } else { + print("input transpose failed in slice kernel") + } + } + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { throw PaddleMobileError.predictError(message: " encode is nil") } - encoder.setTexture(param.input.metalTexture, index: 0) + encoder.setTexture(input.metalTexture, index: 0) encoder.setTexture(param.output.metalTexture, index: 1) encoder.setBytes(&metalParam, length: MemoryLayout.size, index: 0) encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) @@ -72,7 +85,9 @@ class SliceKernel: Kernel, Computable { } else if GlobalConfig.shared.computePrecision == .Float16 { super.init(device: device, inFunctionName: "slice_half", initContext: initContext) } else { - fatalError() + fatalError("unknown computePrecision") } + self.device = device + self.initContext = initContext } }