提交 d4c97eeb 编写于 作者: D dolphin8 提交者: GitHub

Merge pull request #823 from dolphin8/metal

box coder
...@@ -27,6 +27,11 @@ class BoxcoderParam<P: PrecisionType>: OpParam { ...@@ -27,6 +27,11 @@ class BoxcoderParam<P: PrecisionType>: OpParam {
} catch let error { } catch let error {
throw error throw error
} }
assert(priorBox.transpose == [0, 1, 2, 3])
assert(priorBoxVar.transpose == [0, 1, 2, 3])
assert(targetBox.transpose == [0, 1, 2, 3])
assert(codeType == "decode_center_size") // encode_center_size is not implemented
assert((targetBox.tensorDim.cout() == 3) && (targetBox.tensorDim[0] == 1)) // N must be 1 (only handle batch size = 1)
} }
let priorBox: Texture<P> let priorBox: Texture<P>
let priorBoxVar: Texture<P> let priorBoxVar: Texture<P>
......
...@@ -14,18 +14,26 @@ ...@@ -14,18 +14,26 @@
import Foundation import Foundation
/// Constant data that `BoxcoderKernel.compute` hands to the compute encoder
/// through `setBytes` at buffer index 0. It currently declares no fields.
struct BoxcoderMetalParam {
}
/// Encodes the Metal compute pass for the box coder op.
///
/// The prior boxes, prior-box variances and target boxes are bound as input
/// textures, and the op output is bound as the destination texture of the
/// `boxcoder` compute function.
class BoxcoderKernel<P: PrecisionType>: Kernel, Computable{
    /// Encodes one dispatch of the box coder compute function into `commandBuffer`.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer that receives the compute pass.
    ///   - param: Op parameters supplying the three input textures and the output texture.
    /// - Throws: `PaddleMobileError.predictError` when no compute command encoder
    ///   can be created from `commandBuffer`.
    func compute(commandBuffer: MTLCommandBuffer, param: BoxcoderParam<P>) throws {
        guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
            throw PaddleMobileError.predictError(message: " encode is nil")
        }
        // Texture slots follow the [[texture(n)]] indices declared by the
        // `boxcoder` kernel: 0 = priorBox, 1 = priorBoxVar, 2 = targetBox, 3 = output.
        encoder.setTexture(param.priorBox.metalTexture, index: 0)
        encoder.setTexture(param.priorBoxVar.metalTexture, index: 1)
        encoder.setTexture(param.targetBox.metalTexture, index: 2)
        encoder.setTexture(param.output.metalTexture, index: 3)
        // BoxcoderMetalParam has no fields yet; it is still bound at buffer index 0.
        var bmp = BoxcoderMetalParam.init()
        encoder.setBytes(&bmp, length: MemoryLayout<BoxcoderMetalParam>.size, index: 0)
        encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
        encoder.endEncoding()
    }

    /// Prepares the output texture and initializes the kernel for the box coder op.
    ///
    /// - Parameters:
    ///   - device: Metal device used to create the output texture and passed to the superclass.
    ///   - param: Op parameters whose output texture is initialized here.
    required init(device: MTLDevice, param: BoxcoderParam<P>) {
        param.output.initTexture(device: device)
        // The function name must be "boxcoder": `compute` binds four textures in the
        // layout that kernel declares. The previous value "priorbox" named a different function.
        super.init(device: device, inFunctionName: "boxcoder")
    }
}
...@@ -477,3 +477,18 @@ kernel void concat(texture2d_array<float, access::read> in0 [[texture(0)]], ...@@ -477,3 +477,18 @@ kernel void concat(texture2d_array<float, access::read> in0 [[texture(0)]],
} }
out.write(r, gid.xy, gid.z); out.write(r, gid.xy, gid.z);
} }
// Decodes centre-size box regressions into corner-form boxes
// (the "decode_center_size" mode of the box coder op).
//
//   priorBox    - one prior (anchor) box per texel.
//                 NOTE(review): read here as corner form (xmin, ymin, xmax, ymax),
//                 as in the PaddlePaddle box_coder operator; confirm against
//                 whatever fills this texture.
//   priorBoxVar - the four variances applied to the encoded offsets.
//   targetBox   - encoded offsets (dx, dy, dw, dh) relative to the prior box.
//   output      - decoded box written as (xmin, ymin, xmax, ymax).
//   gid         - thread position; all textures are addressed with the same
//                 gid.xy / gid.z coordinates.
kernel void boxcoder(texture2d_array<float, access::read> priorBox [[texture(0)]],
                     texture2d_array<float, access::read> priorBoxVar [[texture(1)]],
                     texture2d_array<float, access::read> targetBox [[texture(2)]],
                     texture2d_array<float, access::write> output[[texture(3)]],
                     uint3 gid [[thread_position_in_grid]]) {
    const float4 t = targetBox.read(gid.xy, gid.z);
    const float4 p = priorBox.read(gid.xy, gid.z);
    const float4 pv = priorBoxVar.read(gid.xy, gid.z);

    // Prior box: corner form -> centre / size.
    const float priorW = p.z - p.x;
    const float priorH = p.w - p.y;
    const float priorCx = (p.x + p.z) * 0.5;
    const float priorCy = (p.y + p.w) * 0.5;

    // Decoded box in centre / size form.
    const float cx = pv.x * t.x * priorW + priorCx;
    const float cy = pv.y * t.y * priorH + priorCy;
    const float w = exp(pv.z * t.z) * priorW;
    const float h = exp(pv.w * t.w) * priorH;

    // Corners are centre -/+ half of the *decoded* size. The previous code used
    // the raw encoded deltas t.z / t.w here, which produced meaningless boxes.
    const float4 box = float4(cx - w * 0.5,
                              cy - h * 0.5,
                              cx + w * 0.5,
                              cy + h * 0.5);
    output.write(box, gid.xy, gid.z);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册