提交 e71320da 编写于 作者: D dolphin8

batchnorm

上级 2bcb1135
...@@ -16,8 +16,9 @@ ...@@ -16,8 +16,9 @@
4AA1EA92214665D700D0F791 /* ShapeOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA91214665D700D0F791 /* ShapeOp.swift */; }; 4AA1EA92214665D700D0F791 /* ShapeOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA91214665D700D0F791 /* ShapeOp.swift */; };
4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA932146661500D0F791 /* ShapeKernel.swift */; }; 4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA932146661500D0F791 /* ShapeKernel.swift */; };
4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA972146666500D0F791 /* FlattenOp.swift */; }; 4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA972146666500D0F791 /* FlattenOp.swift */; };
4AA1EA9E2148D6F900D0F791 /* ConcatKernel.metal.inc in Headers */ = {isa = PBXBuildFile; fileRef = 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */; }; 4AA1EA9E2148D6F900D0F791 /* ConcatKernel.inc.metal in Headers */ = {isa = PBXBuildFile; fileRef = 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */; };
4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.metal.inc in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.metal.inc */; }; 4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */; };
4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */; };
4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; }; 4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; };
4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; }; 4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; };
4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; }; 4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; };
...@@ -126,8 +127,9 @@ ...@@ -126,8 +127,9 @@
4AA1EA91214665D700D0F791 /* ShapeOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeOp.swift; sourceTree = "<group>"; }; 4AA1EA91214665D700D0F791 /* ShapeOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeOp.swift; sourceTree = "<group>"; };
4AA1EA932146661500D0F791 /* ShapeKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeKernel.swift; sourceTree = "<group>"; }; 4AA1EA932146661500D0F791 /* ShapeKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeKernel.swift; sourceTree = "<group>"; };
4AA1EA972146666500D0F791 /* FlattenOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenOp.swift; sourceTree = "<group>"; }; 4AA1EA972146666500D0F791 /* FlattenOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenOp.swift; sourceTree = "<group>"; };
4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ConcatKernel.metal.inc; sourceTree = "<group>"; }; 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ConcatKernel.inc.metal; sourceTree = "<group>"; };
4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.metal.inc */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ReshapeKernel.metal.inc; sourceTree = "<group>"; }; 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ReshapeKernel.inc.metal; sourceTree = "<group>"; };
4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenKernel.swift; sourceTree = "<group>"; };
4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; }; 4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; };
4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; }; 4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; };
4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = "<group>"; }; 4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = "<group>"; };
...@@ -395,6 +397,7 @@ ...@@ -395,6 +397,7 @@
FCD04E6720F315020007374F /* PoolKernel.swift */, FCD04E6720F315020007374F /* PoolKernel.swift */,
FCD04E6B20F31A280007374F /* SoftmaxKernel.swift */, FCD04E6B20F31A280007374F /* SoftmaxKernel.swift */,
FCD04E6F20F31B720007374F /* ReshapeKernel.swift */, FCD04E6F20F31B720007374F /* ReshapeKernel.swift */,
4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */,
FCD04E7320F3437E0007374F /* ConvAddKernel.swift */, FCD04E7320F3437E0007374F /* ConvAddKernel.swift */,
FCBCCC5A2122F66F00D94F7E /* ConvBNReluKernel.swift */, FCBCCC5A2122F66F00D94F7E /* ConvBNReluKernel.swift */,
FCBCCC602122FBDF00D94F7E /* PriorBoxKernel.swift */, FCBCCC602122FBDF00D94F7E /* PriorBoxKernel.swift */,
...@@ -442,7 +445,7 @@ ...@@ -442,7 +445,7 @@
children = ( children = (
FC27990D21341016000B6BAD /* BoxCoder.metal */, FC27990D21341016000B6BAD /* BoxCoder.metal */,
4AF928812135673D005B6C3A /* ConcatKernel.metal */, 4AF928812135673D005B6C3A /* ConcatKernel.metal */,
4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */, 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */,
4AF9288321357BE3005B6C3A /* Elementwise.metal */, 4AF9288321357BE3005B6C3A /* Elementwise.metal */,
FC1B16B220EC9A4F00678B91 /* Kernels.metal */, FC1B16B220EC9A4F00678B91 /* Kernels.metal */,
FC4CB74820F0B954007C0C6D /* ConvKernel.metal */, FC4CB74820F0B954007C0C6D /* ConvKernel.metal */,
...@@ -455,7 +458,7 @@ ...@@ -455,7 +458,7 @@
FCDDC6CB212FDFDB00E5EF74 /* ReluKernel.metal */, FCDDC6CB212FDFDB00E5EF74 /* ReluKernel.metal */,
FCDDC6CE212FE14700E5EF74 /* PriorBoxKernel.metal */, FCDDC6CE212FE14700E5EF74 /* PriorBoxKernel.metal */,
FCA3A1622132A4AC00084FE5 /* ReshapeKernel.metal */, FCA3A1622132A4AC00084FE5 /* ReshapeKernel.metal */,
4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.metal.inc */, 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */,
FCA3A1642132A5EB00084FE5 /* Common.metal */, FCA3A1642132A5EB00084FE5 /* Common.metal */,
FCA67B1621364EF000BD58AA /* ConvTransposeKernel.metal */, FCA67B1621364EF000BD58AA /* ConvTransposeKernel.metal */,
FCA67CD42138272900BD58AA /* ConvAddMetal.metal */, FCA67CD42138272900BD58AA /* ConvAddMetal.metal */,
...@@ -477,7 +480,7 @@ ...@@ -477,7 +480,7 @@
FC4FD9792140E4980073E130 /* PaddleMobile.h in Headers */, FC4FD9792140E4980073E130 /* PaddleMobile.h in Headers */,
FC292C85214257CB00CF622F /* CPUCompute.h in Headers */, FC292C85214257CB00CF622F /* CPUCompute.h in Headers */,
FC292C5421421B2F00CF622F /* PaddleMobileGPU.h in Headers */, FC292C5421421B2F00CF622F /* PaddleMobileGPU.h in Headers */,
4AA1EA9E2148D6F900D0F791 /* ConcatKernel.metal.inc in Headers */, 4AA1EA9E2148D6F900D0F791 /* ConcatKernel.inc.metal in Headers */,
FC039B6F20E11C3C0081E9F8 /* paddle_mobile.h in Headers */, FC039B6F20E11C3C0081E9F8 /* paddle_mobile.h in Headers */,
); );
runOnlyForDeploymentPostprocessing = 0; runOnlyForDeploymentPostprocessing = 0;
...@@ -617,6 +620,7 @@ ...@@ -617,6 +620,7 @@
FCBCCC592122F42700D94F7E /* ConvBNReluOp.swift in Sources */, FCBCCC592122F42700D94F7E /* ConvBNReluOp.swift in Sources */,
FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */, FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */,
FC9D038420E23B01000F735A /* Texture.swift in Sources */, FC9D038420E23B01000F735A /* Texture.swift in Sources */,
4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */,
4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */, 4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */,
FCBCCC652122FCD700D94F7E /* TransposeOp.swift in Sources */, FCBCCC652122FCD700D94F7E /* TransposeOp.swift in Sources */,
FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */, FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */,
...@@ -657,7 +661,7 @@ ...@@ -657,7 +661,7 @@
FCBCCC67212306B000D94F7E /* ConcatOp.swift in Sources */, FCBCCC67212306B000D94F7E /* ConcatOp.swift in Sources */,
FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */, FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */,
FCEB684A212F00DB00D2448E /* PreluKernel.metal in Sources */, FCEB684A212F00DB00D2448E /* PreluKernel.metal in Sources */,
4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.metal.inc in Sources */, 4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.inc.metal in Sources */,
FC9A19E32148C31300CD9CBF /* MobilenetSSD_AR.swift in Sources */, FC9A19E32148C31300CD9CBF /* MobilenetSSD_AR.swift in Sources */,
FCDDC6CF212FE14700E5EF74 /* PriorBoxKernel.metal in Sources */, FCDDC6CF212FE14700E5EF74 /* PriorBoxKernel.metal in Sources */,
FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */, FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */,
......
...@@ -19,11 +19,14 @@ class BatchNormParam<P: PrecisionType>: OpParam { ...@@ -19,11 +19,14 @@ class BatchNormParam<P: PrecisionType>: OpParam {
required init(opDesc: OpDesc, inScope: Scope) throws { required init(opDesc: OpDesc, inScope: Scope) throws {
do { do {
input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: inScope) input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: inScope)
if input.transpose != [0, 2, 3, 1] {
fatalError("batch norm only accepts NHWC")
}
output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: inScope) output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: inScope)
inputBias = try BatchNormParam.inputBiase(inputs: opDesc.paraInputs, from: inScope) bias = try BatchNormParam.getFirstTensor(key: "Bias", map: opDesc.paraInputs, from: inScope)
inputMean = try BatchNormParam.inputMean(inputs: opDesc.paraInputs, from: inScope) mean = try BatchNormParam.getFirstTensor(key: "Mean", map: opDesc.paraInputs, from: inScope)
inputScale = try BatchNormParam.inputScale(inputs: opDesc.paraInputs, from: inScope) scale = try BatchNormParam.getFirstTensor(key: "Scale", map: opDesc.paraInputs, from: inScope)
inputVariance = try BatchNormParam.inputVariance(inputs: opDesc.paraInputs, from: inScope) variance = try BatchNormParam.getFirstTensor(key: "Variance", map: opDesc.paraInputs, from: inScope)
epsilon = try BatchNormParam.getAttr(key: "epsilon", attrs: opDesc.attrs) epsilon = try BatchNormParam.getAttr(key: "epsilon", attrs: opDesc.attrs)
momentum = try BatchNormParam.getAttr(key: "momentum", attrs: opDesc.attrs) momentum = try BatchNormParam.getAttr(key: "momentum", attrs: opDesc.attrs)
} catch let error { } catch let error {
...@@ -32,10 +35,10 @@ class BatchNormParam<P: PrecisionType>: OpParam { ...@@ -32,10 +35,10 @@ class BatchNormParam<P: PrecisionType>: OpParam {
} }
let input: Texture<P> let input: Texture<P>
var output: Texture<P> var output: Texture<P>
let inputBias: Tensor<ParamPrecisionType> let bias: Tensor<P>
let inputMean: Tensor<ParamPrecisionType> let mean: Tensor<P>
let inputScale: Tensor<ParamPrecisionType> let scale: Tensor<P>
let inputVariance: Tensor<ParamPrecisionType> let variance: Tensor<P>
let epsilon: Float let epsilon: Float
let momentum: Float let momentum: Float
} }
......
...@@ -14,7 +14,24 @@ ...@@ -14,7 +14,24 @@
import Foundation import Foundation
class FlattenOp<P: PrecisionType>: Operator<ReshapeKernel<P>, ReshapeParam<P>>, Runable, Creator, InferShaperable{ class FlattenParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws {
do {
input = try FlattenParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try FlattenParam.outputOut(outputs: opDesc.outputs, from: inScope)
axis = try FlattenParam.getAttr(key: "axis", attrs: opDesc.attrs)
} catch let error {
throw error
}
}
let input: Texture<P>
var output: Texture<P>
let axis: Int
}
class FlattenOp<P: PrecisionType>: Operator<FlattenKernel<P>, FlattenParam<P>>, Runable, Creator, InferShaperable{
typealias OpType = FlattenOp<P> typealias OpType = FlattenOp<P>
......
...@@ -15,20 +15,20 @@ ...@@ -15,20 +15,20 @@
import Foundation import Foundation
class BatchNormKernel<P: PrecisionType>: Kernel, Computable { class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
// var newScale: MTLBuffer
// var newBias: MTLBuffer
//
required init(device: MTLDevice, param: BatchNormParam<P>) { required init(device: MTLDevice, param: BatchNormParam<P>) {
// guard let newScale = device.makeBuffer(length: param.inputScale.buffer.length) else { let count = param.variance.dim.numel()
// fatalError() let varianceP = param.variance.data.pointer
// } let meanP = param.mean.data.pointer
// let scaleP = param.scale.data.pointer
// guard let newBias = device.makeBuffer(length: param.inputBias.buffer.length) else { let biasP = param.scale.data.pointer
// fatalError() for i in 0..<count {
// } let invStd = P(1 / (Float32(varianceP[i]) + param.epsilon).squareRoot())
// self.newScale = newScale biasP[i] = biasP[i] - meanP[i] * invStd * scaleP[i]
// self.newBias = newBias scaleP[i] = invStd * scaleP[i]
// }
param.bias.initBuffer(device: device, precision: computePrecision)
param.scale.initBuffer(device: device, precision: computePrecision)
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision)
if computePrecision == .Float32 { if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "batchnorm") super.init(device: device, inFunctionName: "batchnorm")
} else if computePrecision == .Float16 { } else if computePrecision == .Float16 {
...@@ -36,37 +36,16 @@ class BatchNormKernel<P: PrecisionType>: Kernel, Computable { ...@@ -36,37 +36,16 @@ class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
} else { } else {
fatalError() fatalError()
} }
//
// let varianceBuffer : MTLBuffer = param.inputVariance.buffer
//
// var invStd: [Float32] = Array(repeating: 0, count: varianceBuffer.length)
// let varianceContents = varianceBuffer.contents().assumingMemoryBound(to: P.self)
// for i in 0..<(varianceBuffer.length / MemoryLayout<P>.stride) {
// invStd[i] = 1 / (Float32(varianceContents[i]) + param.epsilon).squareRoot()
// }
//
// let newScaleContents = newScale.contents().assumingMemoryBound(to: P.self)
// let newBiasContents = newBias.contents().assumingMemoryBound(to: P.self)
// let scale : MTLBuffer = param.inputScale.buffer
// let scaleContents = scale.contents().assumingMemoryBound(to: P.self)
// let bias : MTLBuffer = param.inputBias.buffer
// let biasContents = bias.contents().assumingMemoryBound(to: P.self)
// let meanContents = param.inputMean.buffer.contents().assumingMemoryBound(to: P.self)
//
// for i in 0..<(newScale.length / MemoryLayout<P>.stride) {
// newScaleContents[i] = P(invStd[i] * Float32(scaleContents[i]))
// newBiasContents[i] = P(Float32(biasContents[i]) - Float32(meanContents[i]) * invStd[i] * Float32(scaleContents[i]))
// }
} }
func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encoder is nil") throw PaddleMobileError.predictError(message: " encoder is nil")
} }
// encoder.setTexture(param.input.metalTexture, index: 0) encoder.setTexture(param.input.metalTexture, index: 0)
// encoder.setTexture(param.output.metalTexture, index: 1) encoder.setTexture(param.output.metalTexture, index: 1)
// encoder.setBuffer(newScale, offset: 0, index: 0) encoder.setBuffer(param.scale.buffer, offset: 0, index: 0)
// encoder.setBuffer(newBias, offset: 0, index: 1) encoder.setBuffer(param.bias.buffer, offset: 0, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding() encoder.endEncoding()
} }
......
...@@ -122,10 +122,11 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -122,10 +122,11 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{
required init(device: MTLDevice, param: ConcatParam<P>) { required init(device: MTLDevice, param: ConcatParam<P>) {
param.output.initTexture(device: device, inTranspose: param.transpose, computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: param.transpose, computePrecision: computePrecision)
let orank = param.output.tensorDim.cout()
if computePrecision == .Float32 { if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "concat") super.init(device: device, inFunctionName: "concat_\(orank)_float")
} else if computePrecision == .Float16 { } else if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "concat_half") super.init(device: device, inFunctionName: "concat_\(orank)_half")
} else { } else {
fatalError() fatalError()
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
/// Plain-value argument block that `FlattenKernel` uploads to the compute
/// function at buffer index 0.
///
/// The four members are kept in their declared order (`idim`, `itrans`,
/// `odim`, `otrans`); presumably this mirrors the layout of the GPU-side
/// parameter struct — TODO confirm against the Metal source.
struct FlattenMetalParam {
    /// Four 32-bit signed integers, one slot per tensor axis.
    typealias Quad = (Int32, Int32, Int32, Int32)

    /// Input tensor dimensions.
    var idim: Quad
    /// Input transpose (axis permutation).
    var itrans: Quad
    /// Output tensor dimensions.
    var odim: Quad
    /// Output transpose (axis permutation).
    var otrans: Quad
}
/// Compute kernel for the flatten operator.
///
/// It selects one of the `reshape_<inputRank>_2_<precision>` functions and
/// feeds it the input/output dimensions and transposes through a
/// `FlattenMetalParam` value.
class FlattenKernel<P: PrecisionType>: Kernel, Computable{
    /// Dimension and transpose description bound at buffer index 0 in `compute`.
    var metalParam: FlattenMetalParam

    /// Allocates the output texture, builds `metalParam` and picks the function name.
    ///
    /// - Parameters:
    ///   - device: Metal device used to create the output texture and the pipeline.
    ///   - param: Flatten parameters; `input`/`output` supply `tensorDim` and `transpose`.
    ///
    /// NOTE(review): indexing `[0...3]` below assumes both transposes have at least
    /// four entries and both ranks are at most 4 — confirm against `Texture`.
    required init(device: MTLDevice, param: FlattenParam<P>) {
        param.output.initTexture(device: device, computePrecision: computePrecision)

        let irank = param.input.tensorDim.cout()
        let orank = param.output.tensorDim.cout()

        // Right-align the real dimensions in a 4-slot array, padding the front with 1.
        var id: [Int32] = [1, 1, 1, 1]
        for i in 0..<irank {
            id[4 - irank + i] = Int32(param.input.tensorDim[i])
        }
        let it: [Int32] = param.input.transpose.map { Int32($0) }

        var od: [Int32] = [1, 1, 1, 1]
        for i in 0..<orank {
            od[4 - orank + i] = Int32(param.output.tensorDim[i])
        }
        let ot: [Int32] = param.output.transpose.map { Int32($0) }

        metalParam = FlattenMetalParam.init(
            idim: (id[0], id[1], id[2], id[3]),
            itrans: (it[0], it[1], it[2], it[3]),
            odim: (od[0], od[1], od[2], od[3]),
            otrans: (ot[0], ot[1], ot[2], ot[3])
        )

        // The function names below hard-code an output rank of 2.
        assert(orank == 2, "flatten output must have rank 2")
        if computePrecision == .Float32 {
            super.init(device: device, inFunctionName: "reshape_\(irank)_2_float")
        } else if computePrecision == .Float16 {
            super.init(device: device, inFunctionName: "reshape_\(irank)_2_half")
        } else {
            fatalError()
        }
    }

    /// Encodes one dispatch of the flatten function into `commandBuffer`.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute encoder.
    ///   - param: Flatten parameters providing the input and output textures.
    /// - Throws: `PaddleMobileError.predictError` when no compute encoder can be created.
    func compute(commandBuffer: MTLCommandBuffer, param: FlattenParam<P>) throws {
        guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
            throw PaddleMobileError.predictError(message: " encoder is nil")
        }
        encoder.setTexture(param.input.metalTexture, index: 0)
        encoder.setTexture(param.output.metalTexture, index: 1)
        // Size the payload from the type actually being passed, not from a
        // look-alike struct declared elsewhere.
        encoder.setBytes(&metalParam, length: MemoryLayout<FlattenMetalParam>.size, index: 0)
        encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
        encoder.endEncoding()
    }
}
...@@ -49,10 +49,12 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -49,10 +49,12 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{
odim: (od[0], od[1], od[2], od[3]), odim: (od[0], od[1], od[2], od[3]),
otrans: (ot[0], ot[1], ot[2], ot[3]) otrans: (ot[0], ot[1], ot[2], ot[3])
) )
let irank = param.input.tensorDim.cout()
let orank = param.output.tensorDim.cout()
if computePrecision == .Float32 { if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "reshape") super.init(device: device, inFunctionName: "reshape_\(irank)_\(orank)_float")
} else if computePrecision == .Float16 { } else if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "reshape_half") super.init(device: device, inFunctionName: "reshape_\(irank)_\(orank)_half")
} else { } else {
fatalError() fatalError()
} }
......
...@@ -27,7 +27,10 @@ class SplitKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -27,7 +27,10 @@ class SplitKernel<P: PrecisionType>: Kernel, Computable{
} }
required init(device: MTLDevice, param: SplitParam<P>) { required init(device: MTLDevice, param: SplitParam<P>) {
param.output.initTexture(device: device, computePrecision: computePrecision) // param.output.initTexture(device: device, computePrecision: computePrecision)
for output in param.outputList {
output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision)
}
if computePrecision == .Float32 { if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "split") super.init(device: device, inFunctionName: "split")
} else if computePrecision == .Float16 { } else if computePrecision == .Float16 {
......
...@@ -15,28 +15,28 @@ ...@@ -15,28 +15,28 @@
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
kernel void batchnorm_half(texture2d_array<half, access::read> inTexture [[texture(0)]], kernel void batchnorm(texture2d_array<float, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]], texture2d_array<float, access::write> outTexture [[texture(1)]],
const device half4 * newScale [[buffer(0)]], const device float4 * newScale [[buffer(0)]],
const device half4 * newBias [[buffer(1)]], const device float4 * newBias [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) { uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() || if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() || gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return; gid.z >= outTexture.get_array_size()) return;
const half4 input = inTexture.read(gid.xy, gid.z); const float4 input = inTexture.read(gid.xy, gid.z);
half4 output = input * newScale[gid.z] + newBias[gid.z]; float4 output = input * newScale[gid.z] + newBias[gid.z];
outTexture.write(output, gid.xy, gid.z); outTexture.write(output, gid.xy, gid.z);
} }
kernel void batchnorm(texture2d_array<float, access::read> inTexture [[texture(0)]], kernel void batchnorm_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]], texture2d_array<half, access::write> outTexture [[texture(1)]],
const device float4 * newScale [[buffer(0)]], const device half4 * newScale [[buffer(0)]],
const device float4 * newBias [[buffer(1)]], const device half4 * newBias [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) { uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() || if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() || gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return; gid.z >= outTexture.get_array_size()) return;
const float4 input = inTexture.read(gid.xy, gid.z); const half4 input = inTexture.read(gid.xy, gid.z);
float4 output = input * newScale[gid.z] + newBias[gid.z]; half4 output = input * newScale[gid.z] + newBias[gid.z];
outTexture.write(output, gid.xy, gid.z); outTexture.write(output, gid.xy, gid.z);
} }
#ifndef D #ifdef P
#define D 4
#endif
#ifndef P
#define P float
#endif
#define CONCAT2(a, b) a ## b #define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b #define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c #define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define FUNC(f, d, p) CONCAT3_(f, d, p) #define FUNC(f, r, p) CONCAT3_(f, r, p)
#define VECTOR(p, n) CONCAT2(p, n) #define VECTOR(p, n) CONCAT2(p, n)
#define FUNC_D(f, d) CONCAT2_(f, d) #define FUNC_R(f, r) CONCAT2_(f, r)
kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)]], kernel void FUNC(concat, R, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
texture2d_array<P, access::read> in1 [[texture(1)]], texture2d_array<P, access::read> in1 [[texture(1)]],
texture2d_array<P, access::read> in2 [[texture(2)]], texture2d_array<P, access::read> in2 [[texture(2)]],
texture2d_array<P, access::read> in3 [[texture(3)]], texture2d_array<P, access::read> in3 [[texture(3)]],
...@@ -29,10 +23,10 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0) ...@@ -29,10 +23,10 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)
VECTOR(P, 4) r; VECTOR(P, 4) r;
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
xyzn[3] = i; xyzn[3] = i;
#if D == 4 #if R == 4
xyzn2abcd_4(cp.odim[3], xyzn, abcd); xyzn2abcd_4(cp.odim[3], xyzn, abcd);
#else #else
FUNC_D(xyzn2abcd, D)(xyzn, abcd); FUNC_R(xyzn2abcd, R)(xyzn, abcd);
#endif #endif
int k = abcd[cp.axis] - cp.offset; int k = abcd[cp.axis] - cp.offset;
int j = 0; int j = 0;
...@@ -48,10 +42,10 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0) ...@@ -48,10 +42,10 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)
int ta = cp.odim[cp.axis]; int ta = cp.odim[cp.axis];
abcd[cp.axis] = k; abcd[cp.axis] = k;
cp.odim[cp.axis] = cp.vdim[j]; cp.odim[cp.axis] = cp.vdim[j];
#if D == 4 #if R == 4
abcd2xyzn_4(cp.odim[3], abcd, oxyzn); abcd2xyzn_4(cp.odim[3], abcd, oxyzn);
#else #else
FUNC_D(abcd2xyzn, D)(abcd, oxyzn); FUNC_R(abcd2xyzn, R)(abcd, oxyzn);
#endif #endif
cp.odim[cp.axis] = ta; cp.odim[cp.axis] = ta;
switch (j) { switch (j) {
...@@ -66,3 +60,4 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0) ...@@ -66,3 +60,4 @@ kernel void FUNC(concat, D, P)(texture2d_array<P, access::read> in0 [[texture(0)
} }
out.write(r, gid.xy, gid.z); out.write(r, gid.xy, gid.z);
} }
#endif
...@@ -26,31 +26,31 @@ struct ConcatParam { ...@@ -26,31 +26,31 @@ struct ConcatParam {
}; };
#define P float #define P float
#define D 4 #define R 4
#include "ConcatKernel.metal.inc" #include "ConcatKernel.inc.metal"
#undef D #undef R
#define D 3 #define R 3
#include "ConcatKernel.metal.inc" #include "ConcatKernel.inc.metal"
#undef D #undef R
#define D 2 #define R 2
#include "ConcatKernel.metal.inc" #include "ConcatKernel.inc.metal"
#undef D #undef R
#define D 1 #define R 1
#include "ConcatKernel.metal.inc" #include "ConcatKernel.inc.metal"
#undef D #undef R
#undef P #undef P
#define P half #define P half
#define D 4 #define R 4
#include "ConcatKernel.metal.inc" #include "ConcatKernel.inc.metal"
#undef D #undef R
#define D 3 #define R 3
#include "ConcatKernel.metal.inc" #include "ConcatKernel.inc.metal"
#undef D #undef R
#define D 2 #define R 2
#include "ConcatKernel.metal.inc" #include "ConcatKernel.inc.metal"
#undef D #undef R
#define D 1 #define R 1
#include "ConcatKernel.metal.inc" #include "ConcatKernel.inc.metal"
#undef D #undef R
#undef P #undef P
#ifndef P #ifdef P
#define P float
#endif
#define CONCAT2(a, b) a ## b #define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b #define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c #define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d #define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d
#define FUNC(f, d1, d2, p) CONCAT4_(f, d1, d2, p) #define FUNC(f, r1, r2, p) CONCAT4_(f, r1, r2, p)
#define VECTOR(p, n) CONCAT2(p, n) #define VECTOR(p, n) CONCAT2(p, n)
#define FUNC_D(f, d) CONCAT2_(f, d) #define FUNC_R(f, r) CONCAT2_(f, r)
kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTexture [[texture(0)]], kernel void FUNC(reshape, RIN, ROUT, P)(texture2d_array<P, access::read> inTexture [[texture(0)]],
texture2d_array<P, access::write> outTexture [[texture(1)]], texture2d_array<P, access::write> outTexture [[texture(1)]],
constant ReshapeParam &rp [[buffer(0)]], constant ReshapeParam &rp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) { uint3 gid [[thread_position_in_grid]]) {
...@@ -27,10 +25,10 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu ...@@ -27,10 +25,10 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu
VECTOR(P, 4) r; VECTOR(P, 4) r;
for (int n = 0; n < 4; n++) { for (int n = 0; n < 4; n++) {
oxyzn[3] = n; oxyzn[3] = n;
#if DOUT == 4 #if ROUT == 4
xyzn2abcd_4(oC, oxyzn, oabcd); xyzn2abcd_4(oC, oxyzn, oabcd);
#else #else
FUNC_D(xyzn2abcd, DOUT)(oxyzn, oabcd); FUNC_R(xyzn2abcd, ROUT)(oxyzn, oabcd);
#endif #endif
int tabcd[4]; int tabcd[4];
invtrans(lrp.otrans, oabcd, tabcd); invtrans(lrp.otrans, oabcd, tabcd);
...@@ -39,10 +37,10 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu ...@@ -39,10 +37,10 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu
index2abcd(lrp.idim, index, tabcd); index2abcd(lrp.idim, index, tabcd);
trans(lrp.itrans, tabcd, iabcd); trans(lrp.itrans, tabcd, iabcd);
abcd2xyzn(iC, iabcd, ixyzn); abcd2xyzn(iC, iabcd, ixyzn);
#if DIN == 4 #if RIN == 4
abcd2xyzn_4(iC, iabcd, ixyzn); abcd2xyzn_4(iC, iabcd, ixyzn);
#else #else
FUNC_D(abcd2xyzn, DIN)(iabcd, ixyzn); FUNC_R(abcd2xyzn, RIN)(iabcd, ixyzn);
#endif #endif
r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]]; r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]];
} else { } else {
...@@ -52,3 +50,4 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu ...@@ -52,3 +50,4 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array<P, access::read> inTextu
outTexture.write(r, gid.xy, gid.z); outTexture.write(r, gid.xy, gid.z);
} }
#endif
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
...@@ -25,127 +25,126 @@ struct ReshapeParam { ...@@ -25,127 +25,126 @@ struct ReshapeParam {
}; };
#define P float #define P float
#define DIN 4 #define RIN 4
#define DOUT 4 #define ROUT 4
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 3 #define ROUT 3
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 2 #define ROUT 2
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 1 #define ROUT 1
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#undef DIN #undef RIN
#define DIN 3 #define RIN 3
#define DOUT 4 #define ROUT 4
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 3 #define ROUT 3
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 2 #define ROUT 2
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 1 #define ROUT 1
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#undef DIN #undef RIN
#define DIN 2 #define RIN 2
#define DOUT 4 #define ROUT 4
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 3 #define ROUT 3
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 2 #define ROUT 2
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 1 #define ROUT 1
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#undef DIN #undef RIN
#define DIN 1 #define RIN 1
#define DOUT 4 #define ROUT 4
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 3 #define ROUT 3
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 2 #define ROUT 2
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 1 #define ROUT 1
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#undef DIN #undef RIN
#undef P #undef P
#define P half #define P half
#define DIN 4 #define RIN 4
#define DOUT 4 #define ROUT 4
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 3 #define ROUT 3
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 2 #define ROUT 2
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 1 #define ROUT 1
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#undef DIN #undef RIN
#define DIN 3 #define RIN 3
#define DOUT 4 #define ROUT 4
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 3 #define ROUT 3
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 2 #define ROUT 2
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 1 #define ROUT 1
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#undef DIN #undef RIN
#define DIN 2 #define RIN 2
#define DOUT 4 #define ROUT 4
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 3 #define ROUT 3
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 2 #define ROUT 2
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 1 #define ROUT 1
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#undef DIN #undef RIN
#define DIN 1 #define RIN 1
#define DOUT 4 #define ROUT 4
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 3 #define ROUT 3
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 2 #define ROUT 2
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#define DOUT 1 #define ROUT 1
#include "ReshapeKernel.metal.inc" #include "ReshapeKernel.inc.metal"
#undef DOUT #undef ROUT
#undef DIN #undef RIN
#undef P #undef P
...@@ -43,15 +43,12 @@ class ReshapeParam<P: PrecisionType>: OpParam { ...@@ -43,15 +43,12 @@ class ReshapeParam<P: PrecisionType>: OpParam {
} }
output.padToFourDim = Dim.init(inDim: dim) output.padToFourDim = Dim.init(inDim: dim)
output.dim = output.padToFourDim output.dim = output.padToFourDim
// inplace = try ReshapeParam.getAttr(key: "inplace", attrs: opDesc.attrs)
} catch let error { } catch let error {
throw error throw error
} }
} }
let input: Texture<P> let input: Texture<P>
let shape: [Int32] let shape: [Int32]
// let inplace: Bool
var output: Texture<P> var output: Texture<P>
} }
......
...@@ -18,17 +18,19 @@ class ShapeParam<P: PrecisionType>: OpParam { ...@@ -18,17 +18,19 @@ class ShapeParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws { required init(opDesc: OpDesc, inScope: Scope) throws {
do { do {
output = try ShapeParam.output(outputs: opDesc.outputs, from: inScope) input = try ShapeParam.input(inputs: opDesc.inputs, from: inScope)
output = try ShapeParam.outputOut(outputs: opDesc.outputs, from: inScope)
} catch let error { } catch let error {
throw error throw error
} }
} }
var output: Texture<P> var output: Texture<P>
let input: Texture<P>
} }
class ShapeOp<P: PrecisionType>: Operator<SplitKernel<P>, SplitParam<P>>, Runable, Creator, InferShaperable{ class ShapeOp<P: PrecisionType>: Operator<ShapeKernel<P>, ShapeParam<P>>, Runable, Creator, InferShaperable{
typealias OpType = SplitOp<P> typealias OpType = ShapeOp<P>
func inferShape() { func inferShape() {
// para.output.dim = para.input.dim // para.output.dim = para.input.dim
......
...@@ -18,13 +18,32 @@ class SplitParam<P: PrecisionType>: OpParam { ...@@ -18,13 +18,32 @@ class SplitParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws { required init(opDesc: OpDesc, inScope: Scope) throws {
do { do {
// output = try SplitParam.output(outputs: opDesc.outputs, from: inScope) input = try SplitParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try SplitParam.outputOut(outputs: opDesc.outputs, from: inScope) output = Texture<P>.init(device: input.metalTexture!.device, inDim: input.dim)
axis = try SplitParam.getAttr(key: "axis", attrs: opDesc.attrs)
sections = try SplitParam.getAttr(key: "sections", attrs: opDesc.attrs)
if axis < 0 {
axis = input.tensorDim.cout() + axis
}
guard let outlist = opDesc.outputs["Out"] else {
fatalError()
}
for out in outlist {
guard let variant = inScope[out], let v = variant as? Texture<P> else {
fatalError()
}
outputList.append(v)
sections.append(Int32(v.tensorDim.dims[axis]))
}
} catch let error { } catch let error {
throw error throw error
} }
} }
var axis: Int
let input: Texture<P>
var output: Texture<P> var output: Texture<P>
var outputList: [Texture<P>] = []
var sections: [Int32] = []
} }
class SplitOp<P: PrecisionType>: Operator<SplitKernel<P>, SplitParam<P>>, Runable, Creator, InferShaperable{ class SplitOp<P: PrecisionType>: Operator<SplitKernel<P>, SplitParam<P>>, Runable, Creator, InferShaperable{
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册