提交 d123b52e 编写于 作者: D dolphin8

fix batchnorm

上级 caec10b7
......@@ -19,27 +19,36 @@ class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
var newBias: MTLBuffer
required init(device: MTLDevice, param: BatchNormParam<P>) {
guard let newScale = device.makeBuffer(length: param.inputScale.buffer.length) else {
fatalError()
}
guard let newBias = device.makeBuffer(length: param.inputBias.buffer.length) else {
fatalError()
}
self.newScale = newScale
self.newBias = newBias
super.init(device: device, inFunctionName: "batchnorm")
let varianceBuffer = param.inputVariance.buffer
let varianceBuffer : MTLBuffer = param.inputVariance.buffer
var invStd: [Float32] = Array(repeating: 0, count: varianceBuffer.length)
let varianceContents = varianceBuffer.contents().assumingMemoryBound(to: P.self)
for i in 0..<(varianceBuffer.length / MemoryLayout<P>.stride) {
invStd[i] = 1 / Float32(varianceContents[i] + param.epsilon).squareRoot()
invStd[i] = 1 / (Float32(varianceContents[i]) + param.epsilon).squareRoot()
}
var newScale = device.makeBuffer(param.inputScale.buffer.length)
var newBias = device.makeBuffer(param.inputBias.buffer.length)
var newScaleContents = newScale.contents().assumingMemoryBound(to: P.self)
var newBiasContents = newBias.contents().assumingMemoryBound(to: P.self)
let scale = param.inputScale.buffer
let newScaleContents = newScale.contents().assumingMemoryBound(to: P.self)
let newBiasContents = newBias.contents().assumingMemoryBound(to: P.self)
let scale : MTLBuffer = param.inputScale.buffer
let scaleContents = scale.contents().assumingMemoryBound(to: P.self)
let bias = param.inputBias.buffer
let bias : MTLBuffer = param.inputBias.buffer
let biasContents = bias.contents().assumingMemoryBound(to: P.self)
let meanContents = param.inputMean.buffer.contents().assumingMemoryBound(to: P.self)
for i in 0..<(scaleContents.lengh / MemoryLayout<P>.stride) {
newScaleContents[i] = invStd[i] * scaleContents[i]
newBiasContents[i] = biasContents[i] - meanContents[i] * invStd[i] * scaleContents[i]
for i in 0..<(newScale.length / MemoryLayout<P>.stride) {
newScaleContents[i] = P(invStd[i] * Float32(scaleContents[i]))
newBiasContents[i] = P(Float32(biasContents[i]) - Float32(meanContents[i]) * invStd[i] * Float32(scaleContents[i]))
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册