diff --git a/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift
index 3b45d97c307216b0277193b682ac030d1326a0db..3761dad60f0f8b20e3f95168445317a3e627ada9 100644
--- a/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift
+++ b/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift
@@ -48,6 +48,11 @@ class BatchNormOp<P: PrecisionType>: Operator<BatchNormKernel<P>, BatchNormParam<P>>
   }
   typealias OpType = BatchNormOp<P>
 
   func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
+    do {
+      try kernel.compute(commandBuffer: buffer, param: para)
+    } catch let error {
+      throw error
+    }
   }
 }
diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift
index 4c36543abe5e66aafd5ba1c4ce0ec7610fe14369..fb51966491b83fca1b3653127c50fc885ad54def 100644
--- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift
+++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift
@@ -15,11 +15,51 @@
 import Foundation
 
 class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
+  var newScale: MTLBuffer
+  var newBias: MTLBuffer
+
   required init(device: MTLDevice, param: BatchNormParam<P>) {
+    // Create the folded parameter buffers first: stored properties must be
+    // initialized before the call to super.init.
+    newScale = device.makeBuffer(length: param.inputScale.buffer.length, options: [])!
+    newBias = device.makeBuffer(length: param.inputBias.buffer.length, options: [])!
     super.init(device: device, inFunctionName: "batchnorm")
+
+    // invStd[i] = 1 / sqrt(variance[i] + epsilon)
+    let varianceBuffer = param.inputVariance.buffer
+    let count = varianceBuffer.length / MemoryLayout<P>.stride
+    var invStd: [Float32] = Array(repeating: 0, count: count)
+    let varianceContents = varianceBuffer.contents().assumingMemoryBound(to: P.self)
+    for i in 0..<count {
+      invStd[i] = 1 / Float32(varianceContents[i] + param.epsilon).squareRoot()
+    }
+
+    // Fold mean and variance into the scale/bias pair the shader reads:
+    // newScale = scale * invStd, newBias = bias - mean * invStd * scale
+    let newScaleContents = newScale.contents().assumingMemoryBound(to: P.self)
+    let newBiasContents = newBias.contents().assumingMemoryBound(to: P.self)
+    let scale = param.inputScale.buffer
+    let scaleContents = scale.contents().assumingMemoryBound(to: P.self)
+    let bias = param.inputBias.buffer
+    let biasContents = bias.contents().assumingMemoryBound(to: P.self)
+    let meanContents = param.inputMean.buffer.contents().assumingMemoryBound(to: P.self)
+
+    for i in 0..<(scale.length / MemoryLayout<P>.stride) {
+      newScaleContents[i] = invStd[i] * scaleContents[i]
+      newBiasContents[i] = biasContents[i] - meanContents[i] * invStd[i] * scaleContents[i]
+    }
   }
 
   func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam<P>) throws {
-
+    guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
+      throw PaddleMobileError.predictError(message: " encoder is nil")
+    }
+    print("BatchNorm compute")
+    encoder.setTexture(param.input.metalTexture, index: 0)
+    encoder.setTexture(param.output.metalTexture, index: 1)
+    encoder.setBuffer(newScale, offset: 0, index: 0)
+    encoder.setBuffer(newBias, offset: 0, index: 1)
+    encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
+    encoder.endEncoding()
   }
 }
diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal
index 9b202174defee112ea3dfc5c48c95563a8408467..82b9fa2ab04e4107dac64997a77a4eae5c4cdcec 100644
--- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal
+++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal
@@ -60,16 +60,16 @@ kernel void elementwise_add
   outTexture.write(input, gid.xy, gid.z);
 }
 
-
-
-
 kernel void batchnorm(texture2d_array<half, access::read> inTexture [[texture(0)]],
-                      texture2d_array<half, access::write> outTexture [[texture(1)]],
-                      uint3 gid [[thread_position_in_grid]]) {
+                      texture2d_array<half, access::write> outTexture [[texture(1)]],
+                      const device half4 * newScale [[buffer(0)]],
+                      const device half4 * newBias [[buffer(1)]],
+                      uint3 gid [[thread_position_in_grid]]) {
   if (gid.x >= outTexture.get_width() ||
       gid.y >= outTexture.get_height() ||
       gid.z >= outTexture.get_array_size()) return;
   const half4 input = inTexture.read(gid.xy, gid.z);
-  outTexture.write(input, gid.xy, gid.z);
+  half4 output = input * newScale[gid.z] + newBias[gid.z];
+  outTexture.write(output, gid.xy, gid.z);
 }
 
@@ -85,4 +85,3 @@ kernel void texture2d_to_2d_array
 }
 
 
-
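Note: the precomputation in BatchNormKernel.init is the usual batch-norm folding,
y = (x - mean) / sqrt(variance + epsilon) * scale + bias = x * newScale + newBias,
so the shader only has to do one fused multiply-add per element. Below is a minimal
standalone Swift sketch of that identity, using plain Float arrays instead of
MTLBuffer contents; the helper name foldBatchNorm is illustrative and not part of
the patch.

import Foundation

/// Fold batch-norm statistics into one per-channel scale and bias,
/// mirroring what BatchNormKernel.init precomputes for the shader.
func foldBatchNorm(scale: [Float], bias: [Float], mean: [Float],
                   variance: [Float], epsilon: Float) -> (newScale: [Float], newBias: [Float]) {
  var newScale = [Float](repeating: 0, count: scale.count)
  var newBias = [Float](repeating: 0, count: scale.count)
  for i in 0..<scale.count {
    let invStd = 1 / (variance[i] + epsilon).squareRoot()
    newScale[i] = scale[i] * invStd
    newBias[i] = bias[i] - mean[i] * invStd * scale[i]
  }
  return (newScale, newBias)
}

// The kernel then computes: output = input * newScale[c] + newBias[c]
let (s, b) = foldBatchNorm(scale: [1.0], bias: [0.5], mean: [2.0],
                           variance: [4.0], epsilon: 1e-5)
let x: Float = 3.0
let direct = (x - 2.0) / (4.0 + 1e-5).squareRoot() * 1.0 + 0.5
let fused = x * s[0] + b[0]
print(direct, fused)  // both print ~1.0; the two forms agree up to rounding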