diff --git a/metal/paddle-mobile/paddle-mobile/Operators/BoxcoderOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/BoxcoderOp.swift index 4cea0c7c58b8ff89196d49962eca2624d1fed6a1..f50664dd1a7c165c5df06965cdca992492a86650 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/BoxcoderOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/BoxcoderOp.swift @@ -27,6 +27,11 @@ class BoxcoderParam: OpParam { } catch let error { throw error } + assert(priorBox.transpose == [0, 1, 2, 3]) + assert(priorBoxVar.transpose == [0, 1, 2, 3]) + assert(targetBox.transpose == [0, 1, 2, 3]) + assert(codeType == "decode_center_size") // encode_center_size is not implemented + assert((targetBox.tensorDim.cout() == 3) && (targetBox.tensorDim[0] == 1)) // N must be 1 (only handle batch size = 1) } let priorBox: Texture

let priorBoxVar: Texture

diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BoxcoderKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BoxcoderKernel.swift
index 461652e8c3e53114b36b89a4be785443d3756dcc..26d9c2cdf64ced43ec459b4a2de8304ff37dc222 100644
--- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BoxcoderKernel.swift
+++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BoxcoderKernel.swift
@@ -14,18 +14,36 @@
 
 import Foundation
 
+/// Constant block bound at buffer index 0 of the compute encoder; it has no fields yet.
+struct BoxcoderMetalParam {
+}
+
+/// Encodes the box-coder compute pass for one `BoxcoderParam`.
 class BoxcoderKernel: Kernel, Computable{
+    /// Binds the prior boxes, their variances, the target boxes and the output texture,
+    /// then dispatches the pipeline on `commandBuffer`.
+    /// - Throws: `PaddleMobileError.predictError` if no compute encoder can be created.
     func compute(commandBuffer: MTLCommandBuffer, param: BoxcoderParam) throws {
         guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
             throw PaddleMobileError.predictError(message: " encode is nil")
         }
-//        encoder.setTexture(param.input.metalTexture, index: 0)
-        encoder.setTexture(param.output.metalTexture, index: 1)
+        // Texture slots 0-3 follow the argument order of the `boxcoder` kernel in Kernels.metal.
+        encoder.setTexture(param.priorBox.metalTexture, index: 0)
+        encoder.setTexture(param.priorBoxVar.metalTexture, index: 1)
+        encoder.setTexture(param.targetBox.metalTexture, index: 2)
+        encoder.setTexture(param.output.metalTexture, index: 3)
+        var bmp = BoxcoderMetalParam.init()
+        // NOTE(review): BoxcoderMetalParam is empty, so the length passed here is 0, and the
+        // `boxcoder` kernel declares no buffer argument - confirm this binding is still needed.
+        encoder.setBytes(&bmp, length: MemoryLayout<BoxcoderMetalParam>.size, index: 0)
         encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
         encoder.endEncoding()
     }
 
+    /// Initializes the output texture on `device`, then sets up the base kernel with the
+    /// `boxcoder` Metal function, whose texture slots match the bindings made in `compute`.
     required init(device: MTLDevice, param: BoxcoderParam) {
+        param.output.initTexture(device: device)
-        super.init(device: device, inFunctionName: "priorbox")
+        super.init(device: device, inFunctionName: "boxcoder")
     }
 }
diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal
index 234fbd0fd61fed4def08a3526a4b5fe17aae80fc..a00c7a71a466b6754d0aa52f94bf99bb03531373 100644
--- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal
+++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal
@@ -477,3 +477,22 @@ kernel void concat(texture2d_array in0 [[texture(0)]],
     }
     out.write(r, gid.xy, gid.z);
 }
+
+// Decodes one texel per thread: reads the same position (gid.xy, slice gid.z) from targetBox,
+// priorBox and priorBoxVar and writes the four decoded components to output.
+// NOTE(review): this appears to treat p.xy / p.zw as the prior's center / size and uses the raw
+// t.zw terms as half extents - confirm against the reference decode_center_size implementation.
+kernel void boxcoder(texture2d_array<float, access::read> priorBox [[texture(0)]],
+                     texture2d_array<float, access::read> priorBoxVar [[texture(1)]],
+                     texture2d_array<float, access::read> targetBox [[texture(2)]],
+                     texture2d_array<float, access::write> output[[texture(3)]],
+                     uint3 gid [[thread_position_in_grid]]) {
+    float4 t = targetBox.read(gid.xy, gid.z);
+    float4 p = priorBox.read(gid.xy, gid.z);
+    float4 pv = priorBoxVar.read(gid.xy, gid.z);
+    float ox = (p.z * pv.x * t.x + p.x) - t.z / 2;
+    float oy = (p.w * pv.y * t.y + p.y) - t.w / 2;
+    float ow = exp(pv.z * t.z) * p.z + t.z / 2;
+    float oh = exp(pv.w * t.w) * p.w + t.w / 2;
+    output.write(float4(ox, oy, ow, oh), gid.xy, gid.z);
+}