提交 caec10b7 编写于 作者: D dolphin8

batch norm

上级 2eabb65d
...@@ -48,6 +48,11 @@ class BatchNormOp<P: PrecisionType>: Operator<BatchNormKernel<P>, BatchNormParam ...@@ -48,6 +48,11 @@ class BatchNormOp<P: PrecisionType>: Operator<BatchNormKernel<P>, BatchNormParam
} }
typealias OpType = BatchNormOp<P> typealias OpType = BatchNormOp<P>
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
} }
} }
......
...@@ -15,11 +15,44 @@ ...@@ -15,11 +15,44 @@
import Foundation import Foundation
class BatchNormKernel<P: PrecisionType>: Kernel, Computable { class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
var newScale: MTLBuffer
var newBias: MTLBuffer
required init(device: MTLDevice, param: BatchNormParam<P>) { required init(device: MTLDevice, param: BatchNormParam<P>) {
super.init(device: device, inFunctionName: "batchnorm") super.init(device: device, inFunctionName: "batchnorm")
let varianceBuffer = param.inputVariance.buffer
var invStd: [Float32] = Array(repeating: 0, count: varianceBuffer.length)
let varianceContents = varianceBuffer.contents().assumingMemoryBound(to: P.self)
for i in 0..<(varianceBuffer.length / MemoryLayout<P>.stride) {
invStd[i] = 1 / Float32(varianceContents[i] + param.epsilon).squareRoot()
}
var newScale = device.makeBuffer(param.inputScale.buffer.length)
var newBias = device.makeBuffer(param.inputBias.buffer.length)
var newScaleContents = newScale.contents().assumingMemoryBound(to: P.self)
var newBiasContents = newBias.contents().assumingMemoryBound(to: P.self)
let scale = param.inputScale.buffer
let scaleContents = scale.contents().assumingMemoryBound(to: P.self)
let bias = param.inputBias.buffer
let biasContents = bias.contents().assumingMemoryBound(to: P.self)
let meanContents = param.inputMean.buffer.contents().assumingMemoryBound(to: P.self)
for i in 0..<(scaleContents.lengh / MemoryLayout<P>.stride) {
newScaleContents[i] = invStd[i] * scaleContents[i]
newBiasContents[i] = biasContents[i] - meanContents[i] * invStd[i] * scaleContents[i]
}
} }
func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encoder is nil")
}
print("BatchNorm compute")
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setBuffer(newScale, offset: 0, index: 0)
encoder.setBuffer(newBias, offset: 0, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
} }
} }
...@@ -60,16 +60,16 @@ kernel void elementwise_add(texture2d_array<half, access::read> inTexture [[text ...@@ -60,16 +60,16 @@ kernel void elementwise_add(texture2d_array<half, access::read> inTexture [[text
outTexture.write(input, gid.xy, gid.z); outTexture.write(input, gid.xy, gid.z);
} }
kernel void batchnorm(texture2d_array<half, access::read> inTexture [[texture(0)]], kernel void batchnorm(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]], texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) { const device half4 * newScale [[buffer(0)]],
const device half4 * newBias [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() || if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() || gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return; gid.z >= outTexture.get_array_size()) return;
const half4 input = inTexture.read(gid.xy, gid.z); const half4 input = inTexture.read(gid.xy, gid.z);
half4 output = input * newScale[gid.z] + newBias[gid.z];
outTexture.write(input, gid.xy, gid.z); outTexture.write(input, gid.xy, gid.z);
} }
...@@ -85,4 +85,3 @@ kernel void texture2d_to_2d_array(texture2d<half, access::read> inTexture [[text ...@@ -85,4 +85,3 @@ kernel void texture2d_to_2d_array(texture2d<half, access::read> inTexture [[text
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册