提交 8b87a6eb 编写于 作者: D dolphin8

box coder

上级 07bd08b0
......@@ -27,6 +27,11 @@ class BoxcoderParam<P: PrecisionType>: OpParam {
} catch let error {
throw error
}
assert(priorBox.transpose == [0, 1, 2, 3])
assert(priorBoxVar.transpose == [0, 1, 2, 3])
assert(targetBox.transpose == [0, 1, 2, 3])
assert(codeType == "decode_center_size") // encode_center_size is not implemented
assert((targetBox.tensorDim.cout() == 3) && (targetBox.tensorDim[0] == 1)) // N must be 1 (only handle batch size = 1)
}
let priorBox: Texture<P>
let priorBoxVar: Texture<P>
......
......@@ -14,18 +14,26 @@
import Foundation
struct BoxcoderMetalParam {
}
class BoxcoderKernel<P: PrecisionType>: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: BoxcoderParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
// encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setTexture(param.priorBox.metalTexture, index: 0)
encoder.setTexture(param.priorBoxVar.metalTexture, index: 1)
encoder.setTexture(param.targetBox.metalTexture, index: 2)
encoder.setTexture(param.output.metalTexture, index: 3)
var bmp = BoxcoderMetalParam.init()
encoder.setBytes(&bmp, length: MemoryLayout<BoxcoderMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
required init(device: MTLDevice, param: BoxcoderParam<P>) {
param.output.initTexture(device: device)
super.init(device: device, inFunctionName: "priorbox")
}
}
......@@ -477,3 +477,18 @@ kernel void concat(texture2d_array<float, access::read> in0 [[texture(0)]],
}
out.write(r, gid.xy, gid.z);
}
kernel void boxcoder(texture2d_array<float, access::read> priorBox [[texture(0)]],
texture2d_array<float, access::read> priorBoxVar [[texture(1)]],
texture2d_array<float, access::read> targetBox [[texture(2)]],
texture2d_array<float, access::write> output[[texture(3)]],
uint3 gid [[thread_position_in_grid]]) {
float4 t = targetBox.read(gid.xy, gid.z);
float4 p = priorBox.read(gid.xy, gid.z);
float4 pv = priorBoxVar.read(gid.xy, gid.z);
float ox = (p.z * pv.x * t.x + p.x) - t.z / 2;
float oy = (p.w * pv.y * t.y + p.y) - t.w / 2;
float ow = exp(pv.z * t.z) * p.z + t.z / 2;
float oh = exp(pv.w * t.w) * p.w + t.w / 2;
output.write(float4(ox, oy, ow, oh), gid.xy, gid.z);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册