From ad098db10d1f96fe8420f017c024d8867899dbc4 Mon Sep 17 00:00:00 2001 From: NazgulLee Date: Tue, 21 May 2019 19:22:16 +0800 Subject: [PATCH] method to adjust texture transpose (#1642) --- .../Src/Operators/Kernels/Base/Kernel.swift | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/Base/Kernel.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/Base/Kernel.swift index 1928bc9f24..7683a108b5 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/Base/Kernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/Base/Kernel.swift @@ -54,6 +54,52 @@ protocol KernelProtocol { _pipline = device.pipeLine(funcName: funcName, metalLoadMode: initContext.metalLoadMode, metalLibPath: initContext.metalLibPath) } } + + func encodeTransposeInput(input: Texture, toTranspose: [Int], commandBuffer: MTLCommandBuffer, device: MTLDevice, initContext: InitContext) -> Texture? { + do { + let intermediateTexture = Texture(device: device, inDim: input.tensorDim) + try intermediateTexture.initTexture(device: device, inTranspose: toTranspose, computePrecision: GlobalConfig.shared.computePrecision) + + let irank = input.tensorDim.cout() + let orank = intermediateTexture.tensorDim.cout() + var funcName = "" + if GlobalConfig.shared.computePrecision == .Float32 { + funcName = "reshape_\(irank)_\(orank)_float" + } else if GlobalConfig.shared.computePrecision == .Float16 { + funcName = "reshape_\(irank)_\(orank)_half" + } else { + fatalError() + } + let intermediatePipeline = device.pipeLine(funcName: funcName, metalLoadMode: initContext.metalLoadMode, metalLibPath: initContext.metalLibPath) + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { + throw PaddleMobileError.predictError(message: " encode is nil") + } + encoder.setTexture(input.metalTexture, index: 0) + encoder.setTexture(intermediateTexture.metalTexture, index: 1) + var id: [Int32] = [1, 1, 1, 1] + for i in 0...size, index: 0) + encoder.dispatch(computePipline: intermediatePipeline, outTexture: intermediateTexture.metalTexture) + encoder.endEncoding() + return intermediateTexture + } catch _ { + return nil + } + } } @objc public class Shape: NSObject { -- GitLab