提交 033bda03 编写于 作者: W WayneLiu 提交者: Yanzhan Yang

Modify initTexture interface to throw an exception (#1587)

上级 99cccf3b
......@@ -89,7 +89,7 @@ import Foundation
}
}
open func updateProgram(program: Program) {
open func updateProgram(program: Program) throws {
}
}
......@@ -103,7 +103,7 @@ import Foundation
executor = try Executor<Float32>.init(inDevice: inDevice, inQueue: inQueue, inProgram: program!, initContext: initContext)
}
net.updateProgram(program: program!)
try net.updateProgram(program: program!)
} catch let error {
print(error)
return false
......@@ -181,14 +181,20 @@ import Foundation
/// 更新输入维度, 针对可变长输入模型
///
/// - Parameter inDim: 输入维度
@objc public func updateInputDim(inDim: Dim) {
@objc public func updateInputDim(inDim: Dim) -> Bool {
if net.inputDim != inDim {
guard let inProgram = program else {
fatalError(" need load first ")
}
net.inputDim = inDim
net.updateProgram(program: inProgram)
do {
try net.updateProgram(program: inProgram)
} catch let error {
print(error)
return false
}
}
return true
}
public func scaleTexture(input: MTLTexture , complete: @escaping (MTLTexture) -> Void) {
......
......@@ -22,6 +22,8 @@ public protocol SummableMultipliable: Equatable {
}
public protocol PrecisionProtocol: SummableMultipliable{
// init(inFloat: Float32)
// init(inFloat16: Float16)
init<P: PrecisionProtocol>(_ inP: P)
static var bitSize: UInt { get }
static func initializeValue() -> Self
......@@ -48,12 +50,27 @@ extension Float16: PrecisionProtocol {
default:
fatalError()
}
//
// fatalError()
// if P.bitSize == Float32.bitSize {
// self = Float16(inFloat: inP as! Float32)
// } else if P.bitSize == Float16.bitSize {
// self = inP as! Float16
// } else {
// fatalError()
// }
}
public static var bitSize: UInt {
return 16
}
// public init(inFloat16: Float16) {
// self = inFloat16
// }
// public init(inFloat: Float32) {
// self = Int16(inFloat)
// }
}
extension Float32: PrecisionProtocol {
......@@ -75,8 +92,23 @@ extension Float32: PrecisionProtocol {
default:
fatalError()
}
// if P.bitSize == Float32.bitSize {
// self = inP as! Float32
// } else if P.bitSize == Float16.bitSize {
// self = Float32.init(inP as! Float16)
// } else {
// fatalError()
// }
}
// public init(inFloat: Float32) {
// self = inFloat
// }
//
// public init(inFloat16: Float16) {
// self = Float32.init(inFloat16)
// }
//
public static var bitSize: UInt {
return 32
}
......
......@@ -96,14 +96,16 @@ public class Texture: Tensorial {
return metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3]))
}
public func initTexture(device: MTLDevice, inTranspose: [Int] = [0, 1, 2, 3], computePrecision: Precision = .Float16) {
public func initTexture(device: MTLDevice, inTranspose: [Int] = [0, 1, 2, 3], computePrecision: Precision = .Float16) throws {
transpose = inTranspose
for i in 0..<(4 - tensorDim.cout()) {
if i != inTranspose[i] {
fatalError()
// fatalError()
throw PaddleMobileError.loaderError(message: " dims error ")
}
}
let newDim = transpose.map { padToFourDim[$0] }
let newLayout = transpose.map { layout.layoutWithDim[$0] }
......@@ -128,7 +130,8 @@ public class Texture: Tensorial {
tmpTextureDes.height = newDim[2]
tmpTextureDes.arrayLength = 1
default:
fatalError("unreachable")
// fatalError("unreachable")
throw PaddleMobileError.loaderError(message: " unreachable ")
}
if computePrecision == .Float16 {
......@@ -140,10 +143,13 @@ public class Texture: Tensorial {
tmpTextureDes.usage = [.shaderRead, .shaderWrite]
tmpTextureDes.storageMode = .shared
textureDesc = tmpTextureDes
metalTexture = device.makeTexture(descriptor: tmpTextureDes) ?! " texture nil "
guard let inTexture = device.makeTexture(descriptor: tmpTextureDes) else {
throw PaddleMobileError.loaderError(message: " create texture is nil ")
}
metalTexture = inTexture
}
public func updateDims(inTensorDim: Dim, inDim: Dim) {
public func updateDims(inTensorDim: Dim, inDim: Dim) throws {
var fourDim: Dim
if inDim.cout() == 4 {
fourDim = inDim
......@@ -155,7 +161,8 @@ public class Texture: Tensorial {
fourDimNum.append(contentsOf: inDim.dims)
fourDim = Dim.init(inDim: fourDimNum)
} else {
fatalError(" not support ")
// fatalError(" not support ")
throw PaddleMobileError.loaderError(message: " not support ")
}
tensorDim = inTensorDim
......
......@@ -129,10 +129,10 @@ class Operator <KernelType: Computable , ParameterType>: OperatorProtocol where
paraInputs = opDesc.paraInputs
do {
para = try ParamType.init(opDesc:opDesc, inScope: inScope)
kernel = try KernelType.init(device: device, param: para, initContext: initContext)
} catch let error {
throw error
}
kernel = KernelType.init(device: device, param: para, initContext: initContext)
}
typealias ParamType = ParameterType
......
......@@ -28,7 +28,7 @@ public protocol Testable {
protocol Computable {
associatedtype ParamType: OpParam
func compute(commandBuffer: MTLCommandBuffer, param: ParamType) throws
init(device: MTLDevice, param: ParamType, initContext: InitContext)
init(device: MTLDevice, param: ParamType, initContext: InitContext) throws
}
protocol KernelProtocol {
......
......@@ -15,7 +15,7 @@
import Foundation
class BatchNormKernel<P: PrecisionProtocol>: Kernel, Computable {
required init(device: MTLDevice, param: BatchNormParam<P>, initContext: InitContext) {
required init(device: MTLDevice, param: BatchNormParam<P>, initContext: InitContext) throws {
let count = param.variance.dim.numel()
let varianceP = param.variance.data.pointer
let meanP = param.mean.data.pointer
......@@ -29,7 +29,13 @@ class BatchNormKernel<P: PrecisionProtocol>: Kernel, Computable {
param.bias.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.scale.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "batchnorm", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 {
......
......@@ -41,8 +41,14 @@ class BilinearInterpKernel<P: PrecisionProtocol>: Kernel, Computable{
encoder.endEncoding()
}
required init(device: MTLDevice, param: BilinearInterpParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: BilinearInterpParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "bilinear_interp_float", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 {
......
......@@ -32,8 +32,13 @@ class BoxcoderKernel<P: PrecisionProtocol>: Kernel, Computable{
encoder.endEncoding()
}
required init(device: MTLDevice, param: BoxcoderParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: [0, 3, 1, 2], computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: BoxcoderParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: [0, 3, 1, 2], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "boxcoder_float", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 {
......
......@@ -52,8 +52,14 @@ class ConcatKernel<P: PrecisionProtocol>: Kernel, Computable{
encoder.endEncoding()
}
required init(device: MTLDevice, param: ConcatParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: param.transpose, computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: ConcatParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: param.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
let orank = param.output.tensorDim.cout()
let num = param.input.count
assert(num <= 6)
......
......@@ -16,8 +16,14 @@ import Foundation
class ConvAddAddPreluKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvAddAddPreluParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: ConvAddAddPreluParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.y.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.alpha.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
......
......@@ -49,8 +49,14 @@ class ConvAddBatchNormReluKernel<P: PrecisionProtocol>: Kernel, Computable, Test
var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvAddBatchNormReluParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: ConvAddBatchNormReluParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.y.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.variance.initBuffer(device: device, precision: .Float32)
......
......@@ -102,9 +102,13 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
let identifyingKey: String = getUniqueKey()
required init(device: MTLDevice, param: ConvAddParam<P>, initContext: InitContext) {
required init(device: MTLDevice, param: ConvAddParam<P>, initContext: InitContext) throws {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
let offsetY = (Int(param.dilations[1]) * (param.filter.tensorDim[2] - 1) + 1)/2 - Int(param.paddings[1])
let offsetX = (Int(param.dilations[0]) * (param.filter.tensorDim[3] - 1) + 1)/2 - Int(param.paddings[0])
......
......@@ -16,8 +16,14 @@ import Foundation
class ConvAddPreluKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvAddPreluParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: ConvAddPreluParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.y.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.alpha.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
......
......@@ -50,9 +50,14 @@ class ConvBNReluKernel<P: PrecisionProtocol>: Kernel, Computable, Testable {
var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvBNReluParam<P>, initContext: InitContext) {
required init(device: MTLDevice, param: ConvBNReluParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.variance.initBuffer(device: device, precision: .Float32)
param.mean.initBuffer(device: device, precision: .Float32)
......
......@@ -26,7 +26,7 @@ public struct MetalConvParam {
class ConvKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvParam<P>, initContext: InitContext) {
required init(device: MTLDevice, param: ConvParam<P>, initContext: InitContext) throws {
param.filter.initBuffer(device: device, precision: Precision.Float32)
if param.filter.width == 1 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_1x1", initContext: initContext)
......
......@@ -30,8 +30,13 @@ struct MetalConvTransposeParam {
class ConvTransposeKernel<P: PrecisionProtocol>: Kernel, Computable{
var metalParam: MetalConvTransposeParam!
required init(device: MTLDevice, param: ConvTransposeParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: ConvTransposeParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision, convertToNHWC: false, withTranspose: true)
if GlobalConfig.shared.computePrecision == .Float32 {
if param.stride == [2, 2] && param.stride == [2, 2] {
......
......@@ -26,8 +26,13 @@ struct ElementwiseAddMetalParam {
class ElementwiseAddKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: ElementwiseAddMetalParam
required init(device: MTLDevice, param: ElementwiseAddParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: param.inputX.transpose, computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: ElementwiseAddParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: param.inputX.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
metalParam = ElementwiseAddMetalParam.init()
......
......@@ -17,8 +17,14 @@ import Foundation
class ElementwiseAddPreluKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: ElementwiseAddMetalParam
required init(device: MTLDevice, param: ElementwiseAddPreluParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: param.inputX.transpose, computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: ElementwiseAddPreluParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: param.inputX.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
param.alpha.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
metalParam = ElementwiseAddMetalParam.init()
......
......@@ -16,7 +16,7 @@ import Foundation
class FetchKernel<P: PrecisionProtocol>: Kernel, Computable {
required init(device: MTLDevice, param: FetchParam<P>, initContext: InitContext) {
required init(device: MTLDevice, param: FetchParam<P>, initContext: InitContext) throws {
param.output.initBuffer(device: device)
if GlobalConfig.shared.computePrecision == .Float16 {
if param.input.transpose == [0, 2, 3, 1] {
......
......@@ -26,8 +26,14 @@ class FlattenKernel<P: PrecisionProtocol>: Kernel, Computable{
var metalParam: FlattenMetalParam
required init(device: MTLDevice, param: FlattenParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: FlattenParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
var id: [Int32] = [1, 1, 1, 1]
for i in 0..<param.input.tensorDim.cout() {
id[4-param.input.tensorDim.cout()+i] = Int32(param.input.tensorDim[i])
......
......@@ -17,7 +17,7 @@ import Foundation
class MulticlassNMSKernel<P: PrecisionProtocol>: Kernel, Computable{
let pipline1: MTLComputePipelineState
required init(device: MTLDevice, param: MulticlassNMSParam<P>, initContext: InitContext) {
required init(device: MTLDevice, param: MulticlassNMSParam<P>, initContext: InitContext) throws {
param.middleOutput.initBuffer(device: device)
param.bboxOutput.initBuffer(device: device)
......
......@@ -26,8 +26,13 @@ struct PoolMetalParam {
class PoolKernel<P: PrecisionProtocol>: Kernel, Computable{
var metalParam: PoolMetalParam
required init(device: MTLDevice, param: PoolParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: PoolParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
var poolType: Int32
switch param.poolType {
......
......@@ -15,9 +15,15 @@
import Foundation
class PreluKernel<P: PrecisionProtocol>: Kernel, Computable{
required init(device: MTLDevice, param: PreluParam<P>, initContext: InitContext) {
required init(device: MTLDevice, param: PreluParam<P>, initContext: InitContext) throws {
param.alpha.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float32 {
if param.mode == "channel" {
super.init(device: device, inFunctionName: "prelu_channel", initContext: initContext)
......
......@@ -32,15 +32,19 @@ struct PriorBoxMetalParam {
class PriorBoxKernel<P: PrecisionProtocol>: Kernel, Computable{
var metalParam: PriorBoxMetalParam!
required init(device: MTLDevice, param: PriorBoxParam<P>, initContext: InitContext) {
required init(device: MTLDevice, param: PriorBoxParam<P>, initContext: InitContext) throws {
let originDim = param.output.tensorDim;
param.output.tensorDim = Dim.init(inDim: [1, originDim[0], originDim[1], originDim[2] * originDim[3]])
param.output.padToFourDim = Dim.init(inDim: [1, originDim[0], originDim[1], originDim[2] * originDim[3]])
param.output.initTexture(device: device, inTranspose: [0, 1, 2, 3], computePrecision: GlobalConfig.shared.computePrecision)
param.outputVariances.initTexture(device: device, inTranspose: [2, 0, 1, 3], computePrecision: GlobalConfig.shared.computePrecision)
do {
try param.output.initTexture(device: device, inTranspose: [0, 1, 2, 3], computePrecision: GlobalConfig.shared.computePrecision)
try param.outputVariances.initTexture(device: device, inTranspose: [2, 0, 1, 3], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float32 {
if param.min_max_aspect_ratios_order {
......
......@@ -25,7 +25,7 @@ class ReluKernel<P: PrecisionProtocol>: Kernel, Computable{
encoder.endEncoding()
}
required init(device: MTLDevice, param: ReluParam<P>, initContext: InitContext) {
required init(device: MTLDevice, param: ReluParam<P>, initContext: InitContext) throws {
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "relu", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 {
......
......@@ -31,8 +31,14 @@ class ReshapeKernel<P: PrecisionProtocol>: Kernel, Computable{
var metalParam: ReshapeMetalParam
required init(device: MTLDevice, param: ReshapeParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: ReshapeParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
var id: [Int32] = [1, 1, 1, 1]
for i in 0..<param.input.tensorDim.cout() {
id[4-param.input.tensorDim.cout()+i] = Int32(param.input.tensorDim[i])
......
......@@ -20,8 +20,13 @@ struct ResizeBilinearMetalParam {
}
class ResizeBilinearKernel<P: PrecisionProtocol>: Kernel, Computable{
required init(device: MTLDevice, param: ResizeBilinearParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: ResizeBilinearParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "resize_bilinear", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 {
......
......@@ -28,8 +28,14 @@ class ShapeKernel<P: PrecisionProtocol>: Kernel, Computable{
// encoder.endEncoding()
}
required init(device: MTLDevice, param: ShapeParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: ShapeParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "shape", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 {
......
......@@ -22,8 +22,14 @@ struct SoftmaxMetalParam {
class SoftmaxKernel<P: PrecisionProtocol>: Kernel, Computable{
var metalParam: SoftmaxMetalParam
required init(device: MTLDevice, param: SoftmaxParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: SoftmaxParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
metalParam = SoftmaxMetalParam.init(
N: Int32(param.input.tensorDim[0]),
K: Int32(param.input.tensorDim[1])
......
......@@ -37,13 +37,17 @@ class SplitKernel<P: PrecisionProtocol>: Kernel, Computable{
encoder.endEncoding()
}
required init(device: MTLDevice, param: SplitParam<P>, initContext: InitContext) {
required init(device: MTLDevice, param: SplitParam<P>, initContext: InitContext) throws {
// param.output.initTexture(device: device, computePrecision: computePrecision)
let num = param.outputList.count
let rank = param.input.tensorDim.cout()
assert(num >= 2 && num <= 4)
for output in param.outputList {
output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
do {
try output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
}
smp = SplitMetalParam.init()
smp.idim = (Int32(param.input.dim[0]), Int32(param.input.dim[1]), Int32(param.input.dim[2]), Int32(param.input.dim[3]))
......
......@@ -24,8 +24,14 @@ struct Texture2DTo2DArrayParam {
class Texture2DTo2DArrayKernel<P: PrecisionProtocol>: Kernel, Computable{
required init(device: MTLDevice, param: FeedParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: FeedParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "texture2d_to_2d_array_half", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float32 {
......
......@@ -22,8 +22,14 @@ struct TransposeMetalParam {
class TransposeKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: TransposeMetalParam = TransposeMetalParam.init()
required init(device: MTLDevice, param: TransposeParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
required init(device: MTLDevice, param: TransposeParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
let rank = param.input.tensorDim.cout()
var axis: [Int] = [0, 1, 2, 3]
for i in 0..<param.axis.count {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册