diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/Base/Kernel.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/Base/Kernel.swift index 1928bc9f244a9a5f05deb4bd810fee937cdc5da9..7683a108b59e0fcf62efd4a9a0021223286fd7f1 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/Base/Kernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/Base/Kernel.swift @@ -54,6 +54,52 @@ protocol KernelProtocol { _pipline = device.pipeLine(funcName: funcName, metalLoadMode: initContext.metalLoadMode, metalLibPath: initContext.metalLibPath) } } + + func encodeTransposeInput(input: Texture, toTranspose: [Int], commandBuffer: MTLCommandBuffer, device: MTLDevice, initContext: InitContext) -> Texture? { + do { + let intermediateTexture = Texture(device: device, inDim: input.tensorDim) + try intermediateTexture.initTexture(device: device, inTranspose: toTranspose, computePrecision: GlobalConfig.shared.computePrecision) + + let irank = input.tensorDim.cout() + let orank = intermediateTexture.tensorDim.cout() + var funcName = "" + if GlobalConfig.shared.computePrecision == .Float32 { + funcName = "reshape_\(irank)_\(orank)_float" + } else if GlobalConfig.shared.computePrecision == .Float16 { + funcName = "reshape_\(irank)_\(orank)_half" + } else { + fatalError() + } + let intermediatePipeline = device.pipeLine(funcName: funcName, metalLoadMode: initContext.metalLoadMode, metalLibPath: initContext.metalLibPath) + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { + throw PaddleMobileError.predictError(message: " encode is nil") + } + encoder.setTexture(input.metalTexture, index: 0) + encoder.setTexture(intermediateTexture.metalTexture, index: 1) + var id: [Int32] = [1, 1, 1, 1] + for i in 0...size, index: 0) + encoder.dispatch(computePipline: intermediatePipeline, outTexture: intermediateTexture.metalTexture) + encoder.endEncoding() + return intermediateTexture + } catch _ { + return nil + } + } } @objc public class Shape: NSObject {