提交 b0c3e1cb 编写于 作者: D dolphin8

fix batchnorm

上级 84a126bf
...@@ -19,27 +19,36 @@ class BatchNormKernel<P: PrecisionType>: Kernel, Computable { ...@@ -19,27 +19,36 @@ class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
var newBias: MTLBuffer var newBias: MTLBuffer
required init(device: MTLDevice, param: BatchNormParam<P>) { required init(device: MTLDevice, param: BatchNormParam<P>) {
guard let newScale = device.makeBuffer(length: param.inputScale.buffer.length) else {
fatalError()
}
guard let newBias = device.makeBuffer(length: param.inputBias.buffer.length) else {
fatalError()
}
self.newScale = newScale
self.newBias = newBias
super.init(device: device, inFunctionName: "batchnorm") super.init(device: device, inFunctionName: "batchnorm")
let varianceBuffer = param.inputVariance.buffer let varianceBuffer : MTLBuffer = param.inputVariance.buffer
var invStd: [Float32] = Array(repeating: 0, count: varianceBuffer.length) var invStd: [Float32] = Array(repeating: 0, count: varianceBuffer.length)
let varianceContents = varianceBuffer.contents().assumingMemoryBound(to: P.self) let varianceContents = varianceBuffer.contents().assumingMemoryBound(to: P.self)
for i in 0..<(varianceBuffer.length / MemoryLayout<P>.stride) { for i in 0..<(varianceBuffer.length / MemoryLayout<P>.stride) {
invStd[i] = 1 / Float32(varianceContents[i] + param.epsilon).squareRoot() invStd[i] = 1 / (Float32(varianceContents[i]) + param.epsilon).squareRoot()
} }
var newScale = device.makeBuffer(param.inputScale.buffer.length)
var newBias = device.makeBuffer(param.inputBias.buffer.length) let newScaleContents = newScale.contents().assumingMemoryBound(to: P.self)
var newScaleContents = newScale.contents().assumingMemoryBound(to: P.self) let newBiasContents = newBias.contents().assumingMemoryBound(to: P.self)
var newBiasContents = newBias.contents().assumingMemoryBound(to: P.self) let scale : MTLBuffer = param.inputScale.buffer
let scale = param.inputScale.buffer
let scaleContents = scale.contents().assumingMemoryBound(to: P.self) let scaleContents = scale.contents().assumingMemoryBound(to: P.self)
let bias = param.inputBias.buffer let bias : MTLBuffer = param.inputBias.buffer
let biasContents = bias.contents().assumingMemoryBound(to: P.self) let biasContents = bias.contents().assumingMemoryBound(to: P.self)
let meanContents = param.inputMean.buffer.contents().assumingMemoryBound(to: P.self) let meanContents = param.inputMean.buffer.contents().assumingMemoryBound(to: P.self)
for i in 0..<(scaleContents.lengh / MemoryLayout<P>.stride) { for i in 0..<(newScale.length / MemoryLayout<P>.stride) {
newScaleContents[i] = invStd[i] * scaleContents[i] newScaleContents[i] = P(invStd[i] * Float32(scaleContents[i]))
newBiasContents[i] = biasContents[i] - meanContents[i] * invStd[i] * scaleContents[i] newBiasContents[i] = P(Float32(biasContents[i]) - Float32(meanContents[i]) * invStd[i] * Float32(scaleContents[i]))
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册