/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ import Foundation import MetalKit import CoreMedia /* NOTE(review): this copy of the file looks flattened and partly stripped - pointer types carry no generic argument and line breaks fall in the middle of signatures. Compare with version control before changing any code here. */ /** A `Variant` that exposes a mutable `dim`, a read-only `layout`, and an element count through `numel()`. */ protocol Tensorial: Variant { var dim: Dim { get set } func numel() -> Int var layout: DataLayout { get } } /** Default `numel()` for every `Tensorial`: forwards to `dim.numel()`. */ extension Tensorial { func numel() -> Int { return dim.numel() } } /** Base class for data converters. It declares a pointer-to-pointer conversion hook and a hook that reports the resulting dim and layout. Both implementations here only call `fatalError`, so a subclass has to override them. */ class DataConverter { /** Conversion hook taking a source pointer, a destination pointer and the dim of the source. The base implementation always traps. */ func convert(from: UnsafeMutablePointer

, to: UnsafeMutablePointer

, fromDim: Dim) { fatalError(" need imp") } /** Hook that maps a source dim and layout to a `(dim, layout)` pair. The base implementation always traps. */ func getToDim(fromDim: Dim, layout: DataLayout) -> (dim: Dim, layout: DataLayout) { fatalError(" need imp") } } /// [ outputChannels ][ inputChannels ][ kernelHeight ][ kernelWidth ] -> /// [ outputChannels ][ kernelHeight ][ kernelWidth ][ inputChannels ] class MPSPointerConverter: DataConverter

{ /// [ outputChannels ][ inputChannels ][ kernelHeight ][ kernelWidth ] -> /// [ outputChannels ][ kernelHeight ][ kernelWidth ][ inputChannels ] /// - Parameters: /// - from: source pointer /// - to: destination pointer /// - fromDim: dim of the source data; indices 0, 1, 2, 3 are read as outputChannels, inputChannels, kernelHeight, kernelWidth override func convert(from: UnsafeMutablePointer

, to: UnsafeMutablePointer

, fromDim: Dim) { let outputChannels = fromDim[0] let inputChannels = fromDim[1] let kernelHeight = fromDim[2] let kernelWidth = fromDim[3] for outChannel in 0.. /* NOTE(review): text appears to be missing at this point - the loop bound, the loop body and the header of the next method are not present (presumably the `getToDim` override, judging by the return type that follows). Restore from version control. */ (dim: Dim, layout: DataLayout) { /* Only an NCHW source layout is accepted; any other layout traps. */ if layout != DataLayout.NCHW() { fatalError("not support") } let outputChannels = fromDim[0] let inputChannels = fromDim[1] let kernelHeight = fromDim[2] let kernelWidth = fromDim[3] /* Source axes 0, 1, 2, 3 are reordered to 0, 2, 3, 1 and the result is reported as NHWC. */ let toDim = Dim.init(inDim: [outputChannels, kernelHeight, kernelWidth, inputChannels]) return (dim: toDim, layout: DataLayout.NHWC()) } } /** Tensor backed by a raw element pointer held in `data` and by a Metal buffer (`buffer`) that `initBuffer` fills in. */ class Tensor: Tensorial { var data: Data var dim: Dim /// Dimensions as stored in the model: an unconverted Paddle model uses N C H W var tensorDim: Dim var buffer: MTLBuffer! private(set) var layout: DataLayout /** Owns a raw element pointer together with its element count (`count`) and a derived `size`. */ class Data { private var released = false let count: Int let size: Int init(inCount: Int, inPointer: UnsafeMutablePointer

) { /* `size` is `inCount` multiplied by a `MemoryLayout` size; the type argument of `MemoryLayout` is not visible in this copy. */ count = inCount size = inCount * MemoryLayout

.size pointer = inPointer } /** The wrapped pointer: readable from outside, assignable only from inside this class. */ internal private(set) var pointer: UnsafeMutablePointer

/** Element access straight through `pointer`; no bounds check is made here. */ subscript(index: Int) -> P { get { return pointer[index] } set { pointer[index] = newValue } } /** Deinitializes and deallocates `pointer` the first time it is called; later calls do nothing because `released` is set. `pointer` keeps its old address afterwards, so it must not be used after this. */ func release() { if !released { pointer.deinitialize(count: count) pointer.deallocate() released = true } } /** Performs the same cleanup as `release()` when that has not already happened. */ deinit { if !released { pointer.deinitialize(count: count) pointer.deallocate() released = true } } } /** Creates a tensor whose `tensorDim` and `dim` are both `inDim`, with layout `inLayout` (NCHW by default). Storage for `inDim.numel()` elements is allocated; no initialization of that storage is visible here. */ init(inDim: Dim, inLayout: DataLayout = DataLayout.NCHW()) { tensorDim = inDim dim = inDim let pointer = UnsafeMutablePointer

.allocate(capacity: inDim.numel()) data = Data.init(inCount: inDim.numel(), inPointer: pointer) layout = inLayout } /** Runs `converter` over the current data into a newly allocated block of `numel()` elements, replaces `data` with a `Data` wrapping that block, then updates `dim` and `layout` from `converter.getToDim`. The returned pointer is the same one now held by `data`. The previous `Data` is dropped by reassignment; `release()` is not called on it here. */ func convert(converter: DataConverter

) -> UnsafeMutablePointer

{ let to = UnsafeMutablePointer

.allocate(capacity: numel()) /* `getToDim` is asked with the old `dim` and `layout`, i.e. before either is overwritten. */ converter.convert(from: data.pointer, to: to, fromDim: dim) data = Data.init(inCount: numel(), inPointer: to) let dimAndLayout = converter.getToDim(fromDim: dim, layout: layout) dim = dimAndLayout.dim layout = dimAndLayout.layout return to } /** Changes the data layout to `to`. Returns without doing anything when `to` already equals `layout`, when `dim.cout()` is not 4, or for any pair other than NCHW to NHWC. */ func convert(to: DataLayout) { guard to != layout else { return } guard dim.cout() == 4 else { return } guard layout == DataLayout.NCHW() && to == DataLayout.NHWC() else { // other not support return } let newPointer = UnsafeMutablePointer

.allocate(capacity: numel()) if layout == DataLayout.NCHW() { NCHW2NHWC(newPtr: newPointer) } /* The old storage is freed before the swap; `data.count` is still readable afterwards because `release()` only touches the pointer and the flag. */ data.release() data = Data.init(inCount: data.count, inPointer: newPointer) layout = to } /** Creates `buffer` on `device`, copies this tensor's data into it, and finally calls `data.release()`. Steps visible here: when `convertToNHWC` is set the layout is first converted to NHWC; the call traps when the model precision is Float16 but `computePrecision` is Float32; a Float32 model with Float16 compute goes through `float32ToFloat16`, every other supported combination is a plain memory copy. In the branch that reads `dim[3]` as the channel count, that count is rounded up to a multiple of 4; in the one-axis branch `numel()` is rounded up to a multiple of 4; any other dim count traps. NOTE(review): the `withTranspose` branch and the padding loops are cut off in this copy, so their behaviour cannot be described from here. */ func initBuffer(device: MTLDevice, precision computePrecision: Precision = .Float16, padWhenOneC: Bool = false, convertToNHWC: Bool = true, withTranspose: Bool = false) { if convertToNHWC { convert(to: DataLayout.NHWC()) } if P.precisionType == .Float16 && computePrecision == .Float32{ fatalError(" 不支持: 16位模型不能按照 32 位进行运算") } if withTranspose { let transposePointer = UnsafeMutablePointer

.allocate(capacity: numel()) let n = dim[0] let hwc = numel()/n /* NOTE(review): the loop below is cut off ("0...stride)") - its bound and body, the rest of the transpose branch and the start of the buffer-size logic are missing from this copy. */ for j in 0...stride) case .Float32: switch computePrecision { case .Float32: buffer?.contents().copyMemory(from: data.pointer, byteCount: count * MemoryLayout

.stride) case .Float16: float32ToFloat16(input: data.pointer as! UnsafeMutablePointer, output: buffer.contents(), count: count) } } } else if C == 1 && !padWhenOneC { /* Single channel without padding: the buffer holds exactly `numel()` elements. */ buffer = device.makeBuffer(length: numel() * precisionSize) switch P.precisionType { case .Float16: buffer?.contents().copyMemory(from: data.pointer, byteCount: numel() * MemoryLayout

.stride) case .Float32: switch computePrecision { case .Float32: buffer?.contents().copyMemory(from: data.pointer, byteCount: numel() * MemoryLayout

.stride) case .Float16: float32ToFloat16(input: data.pointer as! UnsafeMutablePointer, output: buffer.contents(), count: numel()) } } } else { buffer = device.makeBuffer(length: count * precisionSize) let convertedPointer = UnsafeMutablePointer

.allocate(capacity: count) var tmpPointer = data.pointer var dstPtr = convertedPointer /* NOTE(review): the loop below is cut off ("0...stride)") - the code that fills `convertedPointer` is missing from this copy. */ for _ in 0...stride) case .Float32: switch computePrecision { case .Float32: buffer?.contents().copyMemory(from: convertedPointer, byteCount: count * MemoryLayout

.stride) case .Float16: float32ToFloat16(input: convertedPointer as! UnsafeMutablePointer, output: buffer.contents(), count: count) } } convertedPointer.deinitialize(count: count) convertedPointer.deallocate() } } else { /* Channel count comes from dim[3] and is rounded up to a multiple of 4 (`paddedC`); `count` is the padded element total. */ let C = dim[3] let cSlices = (C + 3) / 4 let paddedC = cSlices * 4 let count = paddedC * dim[0] * dim[1] * dim[2] if C == paddedC { buffer = device.makeBuffer(length: count * precisionSize) switch P.precisionType { case .Float16: buffer?.contents().copyMemory(from: data.pointer, byteCount: count * MemoryLayout

.stride) case .Float32: switch computePrecision { case .Float32: buffer?.contents().copyMemory(from: data.pointer, byteCount: count * MemoryLayout

.stride) case .Float16: float32ToFloat16(input: data.pointer as! UnsafeMutablePointer, output: buffer.contents(), count: count) } } } else if C == 1 { fatalError(" not support ") } else { buffer = device.makeBuffer(length: count * precisionSize) let convertedPointer = UnsafeMutablePointer

.allocate(capacity: count) var tmpPointer = data.pointer var dstPtr = convertedPointer /* NOTE(review): the loop below is cut off ("0...stride)") - the code that fills `convertedPointer` is missing from this copy. The line comment further along says 16-bit but sits under the `.Float32` case; it may be stale - verify. */ for _ in 0...stride) case .Float32: // model precision is 16-bit switch computePrecision { case .Float32: buffer?.contents().copyMemory(from: convertedPointer, byteCount: count * MemoryLayout

.stride) case .Float16: float32ToFloat16(input: convertedPointer as! UnsafeMutablePointer, output: buffer.contents(), count: count) } } convertedPointer.deinitialize(count: count) convertedPointer.deallocate() } } } else if dim.cout() == 1 { /* NOTE(review): `num` rounds `numel()` up to a multiple of 4, and the copies below read `num` elements from `data.pointer`. If `data` only holds `numel()` elements this reads past the allocation whenever `numel()` is not a multiple of 4 - confirm against the allocation site. */ let num = ((numel() + 3) / 4) * 4 buffer = device.makeBuffer(length: num * precisionSize) switch P.precisionType { case .Float16: buffer?.contents().copyMemory(from: data.pointer, byteCount: num * MemoryLayout

.stride) case .Float32: switch computePrecision { case .Float32: buffer?.contents().copyMemory(from: data.pointer, byteCount: num * MemoryLayout

.stride) case .Float16: float32ToFloat16(input: data.pointer as! UnsafeMutablePointer, output: buffer.contents(), count: num) } } } else { fatalError(" not support !") } //TODO: release data.release() } /** Size of axis 0 for a four-axis dim in NCHW or NHWC layout; traps for any other layout or axis count. */ var n: Int { get { if dim.cout() == 4 { if layout == DataLayout.NCHW() { return dim[0] } else if layout == DataLayout.NHWC() { return dim[0] } else { fatalError(" unsupport ") } } else { fatalError() } } } /** Width of a four-axis dim: axis 2 for NHWC, axis 3 for NCHW; traps otherwise. */ var width: Int { get { if dim.cout() == 4 { if layout == DataLayout.NHWC() { return dim[2] } else if layout == DataLayout.NCHW() { return dim[3] } else { fatalError(" unsupport ") } } else { fatalError() } } } /** Height of a four-axis dim: axis 1 for NHWC, axis 2 for NCHW; traps otherwise. */ var height: Int { get { if dim.cout() == 4 { if layout == DataLayout.NHWC() { return dim[1] } else if layout == DataLayout.NCHW() { return dim[2] } else { fatalError(" unsupport ") } } else { fatalError() } } } /** Channel count of a four-axis dim: axis 3 for NHWC, axis 1 for NCHW; traps otherwise. */ var channel: Int { get { if dim.cout() == 4 { if layout == DataLayout.NHWC() { return dim[3] } else if layout == DataLayout.NCHW() { return dim[1] } else { fatalError(" unsupport ") } } else { fatalError() } } } func NCHW2NHWC(newPtr: UnsafeMutablePointer

) { let N = dim[0] let C = dim[1] let H = dim[2] let W = dim[3] let HXW = H * W let CXHXW = C * H * W var index: Int = 0 for n in 0...size { str += " \(buffer.contents().assumingMemoryBound(to: P.self)[i])" } return str } func logDataPointer(header: String = "") { print(header) var str = "" str += "data count: \(data.count) \n" str += "dim: \(dim) \n" for i in 0..