提交 ad098db1 编写于 作者: N NazgulLee 提交者: Yanzhan Yang

method to adjust texture transpose (#1642)

上级 6d761d3c
...@@ -54,6 +54,52 @@ protocol KernelProtocol { ...@@ -54,6 +54,52 @@ protocol KernelProtocol {
_pipline = device.pipeLine(funcName: funcName, metalLoadMode: initContext.metalLoadMode, metalLibPath: initContext.metalLibPath) _pipline = device.pipeLine(funcName: funcName, metalLoadMode: initContext.metalLoadMode, metalLibPath: initContext.metalLibPath)
} }
} }
func encodeTransposeInput(input: Texture, toTranspose: [Int], commandBuffer: MTLCommandBuffer, device: MTLDevice, initContext: InitContext) -> Texture? {
do {
let intermediateTexture = Texture(device: device, inDim: input.tensorDim)
try intermediateTexture.initTexture(device: device, inTranspose: toTranspose, computePrecision: GlobalConfig.shared.computePrecision)
let irank = input.tensorDim.cout()
let orank = intermediateTexture.tensorDim.cout()
var funcName = ""
if GlobalConfig.shared.computePrecision == .Float32 {
funcName = "reshape_\(irank)_\(orank)_float"
} else if GlobalConfig.shared.computePrecision == .Float16 {
funcName = "reshape_\(irank)_\(orank)_half"
} else {
fatalError()
}
let intermediatePipeline = device.pipeLine(funcName: funcName, metalLoadMode: initContext.metalLoadMode, metalLibPath: initContext.metalLibPath)
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(input.metalTexture, index: 0)
encoder.setTexture(intermediateTexture.metalTexture, index: 1)
var id: [Int32] = [1, 1, 1, 1]
for i in 0..<input.tensorDim.cout() {
id[4-input.tensorDim.cout()+i] = Int32(input.tensorDim[i])
}
let it: [Int32] = input.transpose.map { Int32($0) }
var od: [Int32] = [1, 1, 1, 1]
for i in 0..<intermediateTexture.tensorDim.cout() {
od[4-intermediateTexture.tensorDim.cout()+i] = Int32(intermediateTexture.tensorDim[i])
}
let ot: [Int32] = intermediateTexture.transpose.map { Int32($0) }
var reshapeMetalParam = ReshapeMetalParam.init(
idim: (id[0], id[1], id[2], id[3]),
itrans: (it[0], it[1], it[2], it[3]),
odim: (od[0], od[1], od[2], od[3]),
otrans: (ot[0], ot[1], ot[2], ot[3])
)
encoder.setBytes(&reshapeMetalParam, length: MemoryLayout<ReshapeMetalParam>.size, index: 0)
encoder.dispatch(computePipline: intermediatePipeline, outTexture: intermediateTexture.metalTexture)
encoder.endEncoding()
return intermediateTexture
} catch _ {
return nil
}
}
} }
@objc public class Shape: NSObject { @objc public class Shape: NSObject {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册