提交 033bda03 编写于 作者: W WayneLiu 提交者: Yanzhan Yang

Modify initTexture interface to throw an exception (#1587)

上级 99cccf3b
...@@ -89,7 +89,7 @@ import Foundation ...@@ -89,7 +89,7 @@ import Foundation
} }
} }
open func updateProgram(program: Program) { open func updateProgram(program: Program) throws {
} }
} }
...@@ -103,7 +103,7 @@ import Foundation ...@@ -103,7 +103,7 @@ import Foundation
executor = try Executor<Float32>.init(inDevice: inDevice, inQueue: inQueue, inProgram: program!, initContext: initContext) executor = try Executor<Float32>.init(inDevice: inDevice, inQueue: inQueue, inProgram: program!, initContext: initContext)
} }
net.updateProgram(program: program!) try net.updateProgram(program: program!)
} catch let error { } catch let error {
print(error) print(error)
return false return false
...@@ -181,14 +181,20 @@ import Foundation ...@@ -181,14 +181,20 @@ import Foundation
/// 更新输入维度, 针对可变长输入模型 /// 更新输入维度, 针对可变长输入模型
/// ///
/// - Parameter inDim: 输入维度 /// - Parameter inDim: 输入维度
@objc public func updateInputDim(inDim: Dim) { @objc public func updateInputDim(inDim: Dim) -> Bool {
if net.inputDim != inDim { if net.inputDim != inDim {
guard let inProgram = program else { guard let inProgram = program else {
fatalError(" need load first ") fatalError(" need load first ")
} }
net.inputDim = inDim net.inputDim = inDim
net.updateProgram(program: inProgram) do {
try net.updateProgram(program: inProgram)
} catch let error {
print(error)
return false
}
} }
return true
} }
public func scaleTexture(input: MTLTexture , complete: @escaping (MTLTexture) -> Void) { public func scaleTexture(input: MTLTexture , complete: @escaping (MTLTexture) -> Void) {
......
...@@ -22,6 +22,8 @@ public protocol SummableMultipliable: Equatable { ...@@ -22,6 +22,8 @@ public protocol SummableMultipliable: Equatable {
} }
public protocol PrecisionProtocol: SummableMultipliable{ public protocol PrecisionProtocol: SummableMultipliable{
// init(inFloat: Float32)
// init(inFloat16: Float16)
init<P: PrecisionProtocol>(_ inP: P) init<P: PrecisionProtocol>(_ inP: P)
static var bitSize: UInt { get } static var bitSize: UInt { get }
static func initializeValue() -> Self static func initializeValue() -> Self
...@@ -48,12 +50,27 @@ extension Float16: PrecisionProtocol { ...@@ -48,12 +50,27 @@ extension Float16: PrecisionProtocol {
default: default:
fatalError() fatalError()
} }
//
// fatalError()
// if P.bitSize == Float32.bitSize {
// self = Float16(inFloat: inP as! Float32)
// } else if P.bitSize == Float16.bitSize {
// self = inP as! Float16
// } else {
// fatalError()
// }
} }
public static var bitSize: UInt { public static var bitSize: UInt {
return 16 return 16
} }
// public init(inFloat16: Float16) {
// self = inFloat16
// }
// public init(inFloat: Float32) {
// self = Int16(inFloat)
// }
} }
extension Float32: PrecisionProtocol { extension Float32: PrecisionProtocol {
...@@ -75,8 +92,23 @@ extension Float32: PrecisionProtocol { ...@@ -75,8 +92,23 @@ extension Float32: PrecisionProtocol {
default: default:
fatalError() fatalError()
} }
// if P.bitSize == Float32.bitSize {
// self = inP as! Float32
// } else if P.bitSize == Float16.bitSize {
// self = Float32.init(inP as! Float16)
// } else {
// fatalError()
// }
} }
// public init(inFloat: Float32) {
// self = inFloat
// }
//
// public init(inFloat16: Float16) {
// self = Float32.init(inFloat16)
// }
//
public static var bitSize: UInt { public static var bitSize: UInt {
return 32 return 32
} }
......
...@@ -96,14 +96,16 @@ public class Texture: Tensorial { ...@@ -96,14 +96,16 @@ public class Texture: Tensorial {
return metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3])) return metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3]))
} }
public func initTexture(device: MTLDevice, inTranspose: [Int] = [0, 1, 2, 3], computePrecision: Precision = .Float16) { public func initTexture(device: MTLDevice, inTranspose: [Int] = [0, 1, 2, 3], computePrecision: Precision = .Float16) throws {
transpose = inTranspose transpose = inTranspose
for i in 0..<(4 - tensorDim.cout()) { for i in 0..<(4 - tensorDim.cout()) {
if i != inTranspose[i] { if i != inTranspose[i] {
fatalError() // fatalError()
throw PaddleMobileError.loaderError(message: " dims error ")
} }
} }
let newDim = transpose.map { padToFourDim[$0] } let newDim = transpose.map { padToFourDim[$0] }
let newLayout = transpose.map { layout.layoutWithDim[$0] } let newLayout = transpose.map { layout.layoutWithDim[$0] }
...@@ -128,7 +130,8 @@ public class Texture: Tensorial { ...@@ -128,7 +130,8 @@ public class Texture: Tensorial {
tmpTextureDes.height = newDim[2] tmpTextureDes.height = newDim[2]
tmpTextureDes.arrayLength = 1 tmpTextureDes.arrayLength = 1
default: default:
fatalError("unreachable") // fatalError("unreachable")
throw PaddleMobileError.loaderError(message: " unreachable ")
} }
if computePrecision == .Float16 { if computePrecision == .Float16 {
...@@ -140,10 +143,13 @@ public class Texture: Tensorial { ...@@ -140,10 +143,13 @@ public class Texture: Tensorial {
tmpTextureDes.usage = [.shaderRead, .shaderWrite] tmpTextureDes.usage = [.shaderRead, .shaderWrite]
tmpTextureDes.storageMode = .shared tmpTextureDes.storageMode = .shared
textureDesc = tmpTextureDes textureDesc = tmpTextureDes
metalTexture = device.makeTexture(descriptor: tmpTextureDes) ?! " texture nil " guard let inTexture = device.makeTexture(descriptor: tmpTextureDes) else {
throw PaddleMobileError.loaderError(message: " create texture is nil ")
}
metalTexture = inTexture
} }
public func updateDims(inTensorDim: Dim, inDim: Dim) { public func updateDims(inTensorDim: Dim, inDim: Dim) throws {
var fourDim: Dim var fourDim: Dim
if inDim.cout() == 4 { if inDim.cout() == 4 {
fourDim = inDim fourDim = inDim
...@@ -155,7 +161,8 @@ public class Texture: Tensorial { ...@@ -155,7 +161,8 @@ public class Texture: Tensorial {
fourDimNum.append(contentsOf: inDim.dims) fourDimNum.append(contentsOf: inDim.dims)
fourDim = Dim.init(inDim: fourDimNum) fourDim = Dim.init(inDim: fourDimNum)
} else { } else {
fatalError(" not support ") // fatalError(" not support ")
throw PaddleMobileError.loaderError(message: " not support ")
} }
tensorDim = inTensorDim tensorDim = inTensorDim
......
...@@ -129,10 +129,10 @@ class Operator <KernelType: Computable , ParameterType>: OperatorProtocol where ...@@ -129,10 +129,10 @@ class Operator <KernelType: Computable , ParameterType>: OperatorProtocol where
paraInputs = opDesc.paraInputs paraInputs = opDesc.paraInputs
do { do {
para = try ParamType.init(opDesc:opDesc, inScope: inScope) para = try ParamType.init(opDesc:opDesc, inScope: inScope)
kernel = try KernelType.init(device: device, param: para, initContext: initContext)
} catch let error { } catch let error {
throw error throw error
} }
kernel = KernelType.init(device: device, param: para, initContext: initContext)
} }
typealias ParamType = ParameterType typealias ParamType = ParameterType
......
...@@ -28,7 +28,7 @@ public protocol Testable { ...@@ -28,7 +28,7 @@ public protocol Testable {
protocol Computable { protocol Computable {
associatedtype ParamType: OpParam associatedtype ParamType: OpParam
func compute(commandBuffer: MTLCommandBuffer, param: ParamType) throws func compute(commandBuffer: MTLCommandBuffer, param: ParamType) throws
init(device: MTLDevice, param: ParamType, initContext: InitContext) init(device: MTLDevice, param: ParamType, initContext: InitContext) throws
} }
protocol KernelProtocol { protocol KernelProtocol {
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import Foundation import Foundation
class BatchNormKernel<P: PrecisionProtocol>: Kernel, Computable { class BatchNormKernel<P: PrecisionProtocol>: Kernel, Computable {
required init(device: MTLDevice, param: BatchNormParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: BatchNormParam<P>, initContext: InitContext) throws {
let count = param.variance.dim.numel() let count = param.variance.dim.numel()
let varianceP = param.variance.data.pointer let varianceP = param.variance.data.pointer
let meanP = param.mean.data.pointer let meanP = param.mean.data.pointer
...@@ -29,7 +29,13 @@ class BatchNormKernel<P: PrecisionProtocol>: Kernel, Computable { ...@@ -29,7 +29,13 @@ class BatchNormKernel<P: PrecisionProtocol>: Kernel, Computable {
param.bias.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision) param.bias.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.scale.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision) param.scale.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "batchnorm", initContext: initContext) super.init(device: device, inFunctionName: "batchnorm", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
......
...@@ -41,8 +41,14 @@ class BilinearInterpKernel<P: PrecisionProtocol>: Kernel, Computable{ ...@@ -41,8 +41,14 @@ class BilinearInterpKernel<P: PrecisionProtocol>: Kernel, Computable{
encoder.endEncoding() encoder.endEncoding()
} }
required init(device: MTLDevice, param: BilinearInterpParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: BilinearInterpParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "bilinear_interp_float", initContext: initContext) super.init(device: device, inFunctionName: "bilinear_interp_float", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
......
...@@ -32,8 +32,13 @@ class BoxcoderKernel<P: PrecisionProtocol>: Kernel, Computable{ ...@@ -32,8 +32,13 @@ class BoxcoderKernel<P: PrecisionProtocol>: Kernel, Computable{
encoder.endEncoding() encoder.endEncoding()
} }
required init(device: MTLDevice, param: BoxcoderParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: BoxcoderParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, inTranspose: [0, 3, 1, 2], computePrecision: GlobalConfig.shared.computePrecision) do {
try param.output.initTexture(device: device, inTranspose: [0, 3, 1, 2], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "boxcoder_float", initContext: initContext) super.init(device: device, inFunctionName: "boxcoder_float", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
......
...@@ -52,8 +52,14 @@ class ConcatKernel<P: PrecisionProtocol>: Kernel, Computable{ ...@@ -52,8 +52,14 @@ class ConcatKernel<P: PrecisionProtocol>: Kernel, Computable{
encoder.endEncoding() encoder.endEncoding()
} }
required init(device: MTLDevice, param: ConcatParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: ConcatParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, inTranspose: param.transpose, computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, inTranspose: param.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
let orank = param.output.tensorDim.cout() let orank = param.output.tensorDim.cout()
let num = param.input.count let num = param.input.count
assert(num <= 6) assert(num <= 6)
......
...@@ -16,8 +16,14 @@ import Foundation ...@@ -16,8 +16,14 @@ import Foundation
class ConvAddAddPreluKernel<P: PrecisionProtocol>: Kernel, Computable { class ConvAddAddPreluKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: MetalConvParam! var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvAddAddPreluParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: ConvAddAddPreluParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision) param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.y.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision) param.y.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.alpha.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision) param.alpha.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
......
...@@ -49,8 +49,14 @@ class ConvAddBatchNormReluKernel<P: PrecisionProtocol>: Kernel, Computable, Test ...@@ -49,8 +49,14 @@ class ConvAddBatchNormReluKernel<P: PrecisionProtocol>: Kernel, Computable, Test
var metalParam: MetalConvParam! var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvAddBatchNormReluParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: ConvAddBatchNormReluParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision) param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.y.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision) param.y.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.variance.initBuffer(device: device, precision: .Float32) param.variance.initBuffer(device: device, precision: .Float32)
......
...@@ -102,9 +102,13 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable { ...@@ -102,9 +102,13 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
let identifyingKey: String = getUniqueKey() let identifyingKey: String = getUniqueKey()
required init(device: MTLDevice, param: ConvAddParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: ConvAddParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision) do {
try param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
let offsetY = (Int(param.dilations[1]) * (param.filter.tensorDim[2] - 1) + 1)/2 - Int(param.paddings[1]) let offsetY = (Int(param.dilations[1]) * (param.filter.tensorDim[2] - 1) + 1)/2 - Int(param.paddings[1])
let offsetX = (Int(param.dilations[0]) * (param.filter.tensorDim[3] - 1) + 1)/2 - Int(param.paddings[0]) let offsetX = (Int(param.dilations[0]) * (param.filter.tensorDim[3] - 1) + 1)/2 - Int(param.paddings[0])
......
...@@ -16,8 +16,14 @@ import Foundation ...@@ -16,8 +16,14 @@ import Foundation
class ConvAddPreluKernel<P: PrecisionProtocol>: Kernel, Computable { class ConvAddPreluKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: MetalConvParam! var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvAddPreluParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: ConvAddPreluParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision) param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.y.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision) param.y.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.alpha.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision) param.alpha.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
......
...@@ -50,9 +50,14 @@ class ConvBNReluKernel<P: PrecisionProtocol>: Kernel, Computable, Testable { ...@@ -50,9 +50,14 @@ class ConvBNReluKernel<P: PrecisionProtocol>: Kernel, Computable, Testable {
var metalParam: MetalConvParam! var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvBNReluParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: ConvBNReluParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision) param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.variance.initBuffer(device: device, precision: .Float32) param.variance.initBuffer(device: device, precision: .Float32)
param.mean.initBuffer(device: device, precision: .Float32) param.mean.initBuffer(device: device, precision: .Float32)
......
...@@ -26,7 +26,7 @@ public struct MetalConvParam { ...@@ -26,7 +26,7 @@ public struct MetalConvParam {
class ConvKernel<P: PrecisionProtocol>: Kernel, Computable { class ConvKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: MetalConvParam! var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: ConvParam<P>, initContext: InitContext) throws {
param.filter.initBuffer(device: device, precision: Precision.Float32) param.filter.initBuffer(device: device, precision: Precision.Float32)
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_1x1", initContext: initContext) super.init(device: device, inFunctionName: "conv_1x1", initContext: initContext)
......
...@@ -30,8 +30,13 @@ struct MetalConvTransposeParam { ...@@ -30,8 +30,13 @@ struct MetalConvTransposeParam {
class ConvTransposeKernel<P: PrecisionProtocol>: Kernel, Computable{ class ConvTransposeKernel<P: PrecisionProtocol>: Kernel, Computable{
var metalParam: MetalConvTransposeParam! var metalParam: MetalConvTransposeParam!
required init(device: MTLDevice, param: ConvTransposeParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: ConvTransposeParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision) do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision, convertToNHWC: false, withTranspose: true) param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision, convertToNHWC: false, withTranspose: true)
if GlobalConfig.shared.computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
if param.stride == [2, 2] && param.stride == [2, 2] { if param.stride == [2, 2] && param.stride == [2, 2] {
......
...@@ -26,8 +26,13 @@ struct ElementwiseAddMetalParam { ...@@ -26,8 +26,13 @@ struct ElementwiseAddMetalParam {
class ElementwiseAddKernel<P: PrecisionProtocol>: Kernel, Computable { class ElementwiseAddKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: ElementwiseAddMetalParam var metalParam: ElementwiseAddMetalParam
required init(device: MTLDevice, param: ElementwiseAddParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: ElementwiseAddParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, inTranspose: param.inputX.transpose, computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, inTranspose: param.inputX.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
metalParam = ElementwiseAddMetalParam.init() metalParam = ElementwiseAddMetalParam.init()
......
...@@ -17,8 +17,14 @@ import Foundation ...@@ -17,8 +17,14 @@ import Foundation
class ElementwiseAddPreluKernel<P: PrecisionProtocol>: Kernel, Computable { class ElementwiseAddPreluKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: ElementwiseAddMetalParam var metalParam: ElementwiseAddMetalParam
required init(device: MTLDevice, param: ElementwiseAddPreluParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: ElementwiseAddPreluParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, inTranspose: param.inputX.transpose, computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, inTranspose: param.inputX.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
param.alpha.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision) param.alpha.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
metalParam = ElementwiseAddMetalParam.init() metalParam = ElementwiseAddMetalParam.init()
......
...@@ -16,7 +16,7 @@ import Foundation ...@@ -16,7 +16,7 @@ import Foundation
class FetchKernel<P: PrecisionProtocol>: Kernel, Computable { class FetchKernel<P: PrecisionProtocol>: Kernel, Computable {
required init(device: MTLDevice, param: FetchParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: FetchParam<P>, initContext: InitContext) throws {
param.output.initBuffer(device: device) param.output.initBuffer(device: device)
if GlobalConfig.shared.computePrecision == .Float16 { if GlobalConfig.shared.computePrecision == .Float16 {
if param.input.transpose == [0, 2, 3, 1] { if param.input.transpose == [0, 2, 3, 1] {
......
...@@ -26,8 +26,14 @@ class FlattenKernel<P: PrecisionProtocol>: Kernel, Computable{ ...@@ -26,8 +26,14 @@ class FlattenKernel<P: PrecisionProtocol>: Kernel, Computable{
var metalParam: FlattenMetalParam var metalParam: FlattenMetalParam
required init(device: MTLDevice, param: FlattenParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: FlattenParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
var id: [Int32] = [1, 1, 1, 1] var id: [Int32] = [1, 1, 1, 1]
for i in 0..<param.input.tensorDim.cout() { for i in 0..<param.input.tensorDim.cout() {
id[4-param.input.tensorDim.cout()+i] = Int32(param.input.tensorDim[i]) id[4-param.input.tensorDim.cout()+i] = Int32(param.input.tensorDim[i])
......
...@@ -17,7 +17,7 @@ import Foundation ...@@ -17,7 +17,7 @@ import Foundation
class MulticlassNMSKernel<P: PrecisionProtocol>: Kernel, Computable{ class MulticlassNMSKernel<P: PrecisionProtocol>: Kernel, Computable{
let pipline1: MTLComputePipelineState let pipline1: MTLComputePipelineState
required init(device: MTLDevice, param: MulticlassNMSParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: MulticlassNMSParam<P>, initContext: InitContext) throws {
param.middleOutput.initBuffer(device: device) param.middleOutput.initBuffer(device: device)
param.bboxOutput.initBuffer(device: device) param.bboxOutput.initBuffer(device: device)
......
...@@ -26,8 +26,13 @@ struct PoolMetalParam { ...@@ -26,8 +26,13 @@ struct PoolMetalParam {
class PoolKernel<P: PrecisionProtocol>: Kernel, Computable{ class PoolKernel<P: PrecisionProtocol>: Kernel, Computable{
var metalParam: PoolMetalParam var metalParam: PoolMetalParam
required init(device: MTLDevice, param: PoolParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: PoolParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
var poolType: Int32 var poolType: Int32
switch param.poolType { switch param.poolType {
......
...@@ -15,9 +15,15 @@ ...@@ -15,9 +15,15 @@
import Foundation import Foundation
class PreluKernel<P: PrecisionProtocol>: Kernel, Computable{ class PreluKernel<P: PrecisionProtocol>: Kernel, Computable{
required init(device: MTLDevice, param: PreluParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: PreluParam<P>, initContext: InitContext) throws {
param.alpha.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision) param.alpha.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "prelu_channel", initContext: initContext) super.init(device: device, inFunctionName: "prelu_channel", initContext: initContext)
......
...@@ -32,15 +32,19 @@ struct PriorBoxMetalParam { ...@@ -32,15 +32,19 @@ struct PriorBoxMetalParam {
class PriorBoxKernel<P: PrecisionProtocol>: Kernel, Computable{ class PriorBoxKernel<P: PrecisionProtocol>: Kernel, Computable{
var metalParam: PriorBoxMetalParam! var metalParam: PriorBoxMetalParam!
required init(device: MTLDevice, param: PriorBoxParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: PriorBoxParam<P>, initContext: InitContext) throws {
let originDim = param.output.tensorDim; let originDim = param.output.tensorDim;
param.output.tensorDim = Dim.init(inDim: [1, originDim[0], originDim[1], originDim[2] * originDim[3]]) param.output.tensorDim = Dim.init(inDim: [1, originDim[0], originDim[1], originDim[2] * originDim[3]])
param.output.padToFourDim = Dim.init(inDim: [1, originDim[0], originDim[1], originDim[2] * originDim[3]]) param.output.padToFourDim = Dim.init(inDim: [1, originDim[0], originDim[1], originDim[2] * originDim[3]])
param.output.initTexture(device: device, inTranspose: [0, 1, 2, 3], computePrecision: GlobalConfig.shared.computePrecision) do {
param.outputVariances.initTexture(device: device, inTranspose: [2, 0, 1, 3], computePrecision: GlobalConfig.shared.computePrecision) try param.output.initTexture(device: device, inTranspose: [0, 1, 2, 3], computePrecision: GlobalConfig.shared.computePrecision)
try param.outputVariances.initTexture(device: device, inTranspose: [2, 0, 1, 3], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
if param.min_max_aspect_ratios_order { if param.min_max_aspect_ratios_order {
......
...@@ -25,7 +25,7 @@ class ReluKernel<P: PrecisionProtocol>: Kernel, Computable{ ...@@ -25,7 +25,7 @@ class ReluKernel<P: PrecisionProtocol>: Kernel, Computable{
encoder.endEncoding() encoder.endEncoding()
} }
required init(device: MTLDevice, param: ReluParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: ReluParam<P>, initContext: InitContext) throws {
if GlobalConfig.shared.computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "relu", initContext: initContext) super.init(device: device, inFunctionName: "relu", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
......
...@@ -31,8 +31,14 @@ class ReshapeKernel<P: PrecisionProtocol>: Kernel, Computable{ ...@@ -31,8 +31,14 @@ class ReshapeKernel<P: PrecisionProtocol>: Kernel, Computable{
var metalParam: ReshapeMetalParam var metalParam: ReshapeMetalParam
required init(device: MTLDevice, param: ReshapeParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: ReshapeParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
var id: [Int32] = [1, 1, 1, 1] var id: [Int32] = [1, 1, 1, 1]
for i in 0..<param.input.tensorDim.cout() { for i in 0..<param.input.tensorDim.cout() {
id[4-param.input.tensorDim.cout()+i] = Int32(param.input.tensorDim[i]) id[4-param.input.tensorDim.cout()+i] = Int32(param.input.tensorDim[i])
......
...@@ -20,8 +20,13 @@ struct ResizeBilinearMetalParam { ...@@ -20,8 +20,13 @@ struct ResizeBilinearMetalParam {
} }
class ResizeBilinearKernel<P: PrecisionProtocol>: Kernel, Computable{ class ResizeBilinearKernel<P: PrecisionProtocol>: Kernel, Computable{
required init(device: MTLDevice, param: ResizeBilinearParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: ResizeBilinearParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision) do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "resize_bilinear", initContext: initContext) super.init(device: device, inFunctionName: "resize_bilinear", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
......
...@@ -28,8 +28,14 @@ class ShapeKernel<P: PrecisionProtocol>: Kernel, Computable{ ...@@ -28,8 +28,14 @@ class ShapeKernel<P: PrecisionProtocol>: Kernel, Computable{
// encoder.endEncoding() // encoder.endEncoding()
} }
required init(device: MTLDevice, param: ShapeParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: ShapeParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "shape", initContext: initContext) super.init(device: device, inFunctionName: "shape", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
......
...@@ -22,8 +22,14 @@ struct SoftmaxMetalParam { ...@@ -22,8 +22,14 @@ struct SoftmaxMetalParam {
class SoftmaxKernel<P: PrecisionProtocol>: Kernel, Computable{ class SoftmaxKernel<P: PrecisionProtocol>: Kernel, Computable{
var metalParam: SoftmaxMetalParam var metalParam: SoftmaxMetalParam
required init(device: MTLDevice, param: SoftmaxParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: SoftmaxParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
metalParam = SoftmaxMetalParam.init( metalParam = SoftmaxMetalParam.init(
N: Int32(param.input.tensorDim[0]), N: Int32(param.input.tensorDim[0]),
K: Int32(param.input.tensorDim[1]) K: Int32(param.input.tensorDim[1])
......
...@@ -37,13 +37,17 @@ class SplitKernel<P: PrecisionProtocol>: Kernel, Computable{ ...@@ -37,13 +37,17 @@ class SplitKernel<P: PrecisionProtocol>: Kernel, Computable{
encoder.endEncoding() encoder.endEncoding()
} }
required init(device: MTLDevice, param: SplitParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: SplitParam<P>, initContext: InitContext) throws {
// param.output.initTexture(device: device, computePrecision: computePrecision) // param.output.initTexture(device: device, computePrecision: computePrecision)
let num = param.outputList.count let num = param.outputList.count
let rank = param.input.tensorDim.cout() let rank = param.input.tensorDim.cout()
assert(num >= 2 && num <= 4) assert(num >= 2 && num <= 4)
for output in param.outputList { for output in param.outputList {
output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision) do {
try output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
} }
smp = SplitMetalParam.init() smp = SplitMetalParam.init()
smp.idim = (Int32(param.input.dim[0]), Int32(param.input.dim[1]), Int32(param.input.dim[2]), Int32(param.input.dim[3])) smp.idim = (Int32(param.input.dim[0]), Int32(param.input.dim[1]), Int32(param.input.dim[2]), Int32(param.input.dim[3]))
......
...@@ -24,8 +24,14 @@ struct Texture2DTo2DArrayParam { ...@@ -24,8 +24,14 @@ struct Texture2DTo2DArrayParam {
class Texture2DTo2DArrayKernel<P: PrecisionProtocol>: Kernel, Computable{ class Texture2DTo2DArrayKernel<P: PrecisionProtocol>: Kernel, Computable{
required init(device: MTLDevice, param: FeedParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: FeedParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float16 { if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "texture2d_to_2d_array_half", initContext: initContext) super.init(device: device, inFunctionName: "texture2d_to_2d_array_half", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float32 { } else if GlobalConfig.shared.computePrecision == .Float32 {
......
...@@ -22,8 +22,14 @@ struct TransposeMetalParam { ...@@ -22,8 +22,14 @@ struct TransposeMetalParam {
class TransposeKernel<P: PrecisionProtocol>: Kernel, Computable { class TransposeKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: TransposeMetalParam = TransposeMetalParam.init() var metalParam: TransposeMetalParam = TransposeMetalParam.init()
required init(device: MTLDevice, param: TransposeParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: TransposeParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
let rank = param.input.tensorDim.cout() let rank = param.input.tensorDim.cout()
var axis: [Int] = [0, 1, 2, 3] var axis: [Int] = [0, 1, 2, 3]
for i in 0..<param.axis.count { for i in 0..<param.axis.count {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册