提交 6e942f3d 编写于 作者: D dolphin8 提交者: GitHub

Merge pull request #981 from dolphin8/metal

......@@ -22,6 +22,7 @@
4AA1EAA4214A295C00D0F791 /* Split.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA3214A295C00D0F791 /* Split.inc.metal */; };
4AA1EAA6214B5F6800D0F791 /* Shape.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA5214B5F6800D0F791 /* Shape.metal */; };
4AA1EAA8214B7AFB00D0F791 /* BilinearInterp.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA7214B7AFB00D0F791 /* BilinearInterp.inc.metal */; };
4AA1EAAA214F53D800D0F791 /* BoxCoder.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA9214F53D800D0F791 /* BoxCoder.inc.metal */; };
4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; };
4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; };
4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; };
......@@ -136,6 +137,7 @@
4AA1EAA3214A295C00D0F791 /* Split.inc.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Split.inc.metal; sourceTree = "<group>"; };
4AA1EAA5214B5F6800D0F791 /* Shape.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Shape.metal; sourceTree = "<group>"; };
4AA1EAA7214B7AFB00D0F791 /* BilinearInterp.inc.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BilinearInterp.inc.metal; sourceTree = "<group>"; };
4AA1EAA9214F53D800D0F791 /* BoxCoder.inc.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.inc.metal; sourceTree = "<group>"; };
4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; };
4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; };
4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = "<group>"; };
......@@ -454,6 +456,7 @@
FC1B16B220EC9A4F00678B91 /* Kernels.metal */,
FC4CB74820F0B954007C0C6D /* ConvKernel.metal */,
4AF928762133F1DB005B6C3A /* BoxCoder.metal */,
4AA1EAA9214F53D800D0F791 /* BoxCoder.inc.metal */,
4AA1EAA5214B5F6800D0F791 /* Shape.metal */,
4AA1EA8F214664CD00D0F791 /* Split.metal */,
4AA1EAA3214A295C00D0F791 /* Split.inc.metal */,
......@@ -583,6 +586,7 @@
buildActionMask = 2147483647;
files = (
FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */,
4AA1EAAA214F53D800D0F791 /* BoxCoder.inc.metal in Sources */,
FC039B9F20E11CB20081E9F8 /* Tensor.swift in Sources */,
FCA67CD7213827AC00BD58AA /* ConvAddBNReluKernel.metal in Sources */,
4AF9287921341661005B6C3A /* Softmax.metal in Sources */,
......@@ -27,6 +27,10 @@ class BoxcoderParam<P: PrecisionType>: OpParam {
} catch let error {
throw error
assert(priorBox.tensorDim.cout() == 2)
assert(priorBoxVar.tensorDim.cout() == 2)
assert(targetBox.tensorDim.cout() == 3)
assert(output.tensorDim.cout() == 3)
assert(priorBox.transpose == [0, 1, 2, 3])
assert(priorBoxVar.transpose == [0, 1, 2, 3])
assert(targetBox.transpose == [0, 1, 2, 3])
......@@ -59,30 +63,19 @@ class BoxcoderOp<P: PrecisionType>: Operator<BoxcoderKernel<P>, BoxcoderParam<P>
/// Debug helper for the box-coder operator.
///
/// Prints a header containing `type`, reads the prior-box variance, prior-box,
/// target-box and output textures back into `[Float32]` arrays, and prints a
/// label for each section.
///
/// NOTE(review): the visible lines contain two read-back styles
/// (`metalTexture.realNHWC` and `device.texture2tensor`) and print only the
/// section labels, not the arrays. This looks like the old and new versions of
/// the function shown together; confirm against the real file.
func delogOutput() {
print(" \(type) output: ")
// Earlier read-back code, left commented out.
// let priorBoxpadToFourDim = para.priorBox.padToFourDim
// let priorBoxArray: [Float32] = para.priorBox.metalTexture.realNHWC(dim: (n: priorBoxpadToFourDim[0], h: priorBoxpadToFourDim[1], w: priorBoxpadToFourDim[2], c: priorBoxpadToFourDim[3]))
// print(" prior box ")
// print(priorBoxArray.strideArray())
// let priorBoxVarpadToFourDim = para.priorBoxVar.padToFourDim
// let priorBoxVarArray: [Float32] = para.priorBoxVar.metalTexture.realNHWC(dim: (n: priorBoxVarpadToFourDim[0], h: priorBoxVarpadToFourDim[1], w: priorBoxVarpadToFourDim[2], c: priorBoxVarpadToFourDim[3]))
// print(" prior box var ")
// print(priorBoxVarArray.strideArray())
// let targetBoxpadToFourDim = para.targetBox.padToFourDim
// let targetBoxArray: [Float32] = para.targetBox.metalTexture.realNHWC(dim: (n: targetBoxpadToFourDim[0], h: targetBoxpadToFourDim[1], w: targetBoxpadToFourDim[2], c: targetBoxpadToFourDim[3]))
// print(" target box ")
// print(targetBoxArray.strideArray())
// Read the target box back using its four-dimension padded shape.
let targetBoxpadToFourDim = para.targetBox.padToFourDim
let targetBoxArray = para.targetBox.metalTexture.realNHWC(dim: (n: targetBoxpadToFourDim[0], h: targetBoxpadToFourDim[1], w: targetBoxpadToFourDim[2], c: targetBoxpadToFourDim[3]))
// Read each tensor back through the device that owns the output texture,
// passing the tensor's own dims and transpose.
let device = para.output.metalTexture!.device
let pbv : [Float32] = device.texture2tensor(texture: para.priorBoxVar.metalTexture!, dim: para.priorBoxVar.tensorDim.dims, transpose: para.priorBoxVar.transpose)
let pb : [Float32] = device.texture2tensor(texture: para.priorBox.metalTexture!, dim: para.priorBox.tensorDim.dims, transpose: para.priorBox.transpose)
let tb : [Float32] = device.texture2tensor(texture: para.targetBox.metalTexture!, dim: para.targetBox.tensorDim.dims, transpose: para.targetBox.transpose)
let out : [Float32] = device.texture2tensor(texture: para.output.metalTexture!, dim: para.output.tensorDim.dims, transpose: para.output.transpose)
print(" prior box var ")
print(" target box ")
// Read the output back using its four-dimension padded shape.
let padToFourDim = para.output.padToFourDim
let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3]))
print(" prior box ")
print(" output ")
......@@ -33,9 +33,9 @@ class BoxcoderKernel<P: PrecisionType>: Kernel, Computable{
required init(device: MTLDevice, param: BoxcoderParam<P>) {
param.output.initTexture(device: device, computePrecision: computePrecision)
param.output.initTexture(device: device, inTranspose: [0, 3, 1, 2], computePrecision: computePrecision)
if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "boxcoder")
super.init(device: device, inFunctionName: "boxcoder_float")
} else if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "boxcoder_half")
} else {
// Precision-generic box-coder kernel. The includer is expected to define the
// macro P as the scalar element type before including this file; everything
// below is guarded by that definition.
#ifdef P
// Token-pasting helpers: CONCAT2 joins two tokens directly, CONCAT2_ joins
// them with an underscore in between.
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
// The extra macro level lets P be expanded before pasting, so
// FUNC(boxcoder, P) names the kernel boxcoder_<P> and VECTOR(P, 4) names the
// 4-component vector type <P>4.
#define FUNC(f, p) CONCAT2_(f, p)
#define VECTOR(p, n) CONCAT2(p, n)
// Decodes one box per thread from a prior box, its variance and a target
// offset, and writes the decoded box to `output`.
//   priorBox    (texture 0): read-only prior boxes.
//   priorBoxVar (texture 1): read-only per-prior variances.
//   targetBox   (texture 2): read-only target values.
//   output      (texture 3): write-only decoded boxes.
//   gid: thread position in the grid; gid.x selects the row that is read and
//        gid.z the array slice.
kernel void FUNC(boxcoder, P)(texture2d_array<P, access::read> priorBox [[texture(0)]],
texture2d_array<P, access::read> priorBoxVar [[texture(1)]],
texture2d_array<P, access::read> targetBox [[texture(2)]],
texture2d_array<P, access::write> output[[texture(3)]],
uint3 gid [[thread_position_in_grid]]) {
// Prior box and variance are each read as one 4-component texel at
// column 0, row gid.x of slice gid.z.
VECTOR(P, 4) p = priorBox.read(uint2(0, gid.x), gid.z);
VECTOR(P, 4) pv = priorBoxVar.read(uint2(0, gid.x), gid.z);
// The target is gathered from four texels (columns 0..3 of row gid.x),
// keeping only the first component of each.
VECTOR(P, 4) t;
t[0] = targetBox.read(uint2(0, gid.x), gid.z)[0];
t[1] = targetBox.read(uint2(1, gid.x), gid.z)[0];
t[2] = targetBox.read(uint2(2, gid.x), gid.z)[0];
t[3] = targetBox.read(uint2(3, gid.x), gid.z)[0];
// Prior centre and size: p.x/p.z are used as the two x extents and p.y/p.w
// as the two y extents.
P px = (p.x + p.z) / 2;
P py = (p.y + p.w) / 2;
P pw = p.z - p.x;
P ph = p.w - p.y;
// Decoded centre: variance-scaled target offset times the prior size,
// shifted by the prior centre.
P tx = pv.x * t.x * pw + px;
P ty = pv.y * t.y * ph + py;
// Decoded size: exp of the variance-scaled target value times the prior size.
P tw = exp(pv.z * t.z) * pw;
P th = exp(pv.w * t.w) * ph;
// Convert centre/size back into the two corner points.
VECTOR(P, 4) r;
r.x = tx - tw / 2;
r.y = ty - th / 2;
r.z = tx + tw / 2;
r.w = ty + th / 2;
// NOTE(review): the inputs are addressed by row gid.x while the result is
// written at gid.xy - presumably the output texture uses a different layout;
// confirm against the dispatch grid and the output's transpose.
output.write(r, gid.xy, gid.z);
......@@ -15,58 +15,9 @@
#include <metal_stdlib>
using namespace metal;
// Float box-coder kernel: decodes one box per thread from a prior box, its
// variance and a target offset.
//   priorBox    (texture 0): read-only prior boxes.
//   priorBoxVar (texture 1): read-only per-prior variances.
//   targetBox   (texture 2): read-only target values.
//   output      (texture 3): write-only decoded boxes.
//   gid: thread position in the grid; gid.xy is the texel and gid.z the slice.
kernel void boxcoder(texture2d_array<float, access::read> priorBox [[texture(0)]],
texture2d_array<float, access::read> priorBoxVar [[texture(1)]],
texture2d_array<float, access::read> targetBox [[texture(2)]],
texture2d_array<float, access::write> output[[texture(3)]],
uint3 gid [[thread_position_in_grid]]) {
// All three inputs are read at the same texel (gid.xy) of slice gid.z.
float4 t = targetBox.read(gid.xy, gid.z);
float4 p = priorBox.read(gid.xy, gid.z);
float4 pv = priorBoxVar.read(gid.xy, gid.z);
// Prior centre and size: p.x/p.z are used as the two x extents and p.y/p.w
// as the two y extents.
float px = (p.x + p.z) / 2;
float py = (p.y + p.w) / 2;
float pw = p.z - p.x;
float ph = p.w - p.y;
// Decoded centre: variance-scaled target offset times the prior size,
// shifted by the prior centre.
float tx = pv.x * t.x * pw + px;
float ty = pv.y * t.y * ph + py;
// Decoded size: exp of the variance-scaled target value times the prior size.
float tw = exp(pv.z * t.z) * pw;
float th = exp(pv.w * t.w) * ph;
// Convert centre/size back into the two corner points.
float4 r;
r.x = tx - tw / 2;
r.y = ty - th / 2;
r.z = tx + tw / 2;
r.w = ty + th / 2;
// The result goes to the same texel and slice the inputs were read from.
output.write(r, gid.xy, gid.z);
// Half-precision box-coder kernel: same decoding as the float kernel, but the
// textures hold half values. Intermediate arithmetic is carried out in float
// and the result is converted to half4 only when written.
//   priorBox    (texture 0): read-only prior boxes.
//   priorBoxVar (texture 1): read-only per-prior variances.
//   targetBox   (texture 2): read-only target values.
//   output      (texture 3): write-only decoded boxes.
//   gid: thread position in the grid; gid.xy is the texel and gid.z the slice.
kernel void boxcoder_half(texture2d_array<half, access::read> priorBox [[texture(0)]],
texture2d_array<half, access::read> priorBoxVar [[texture(1)]],
texture2d_array<half, access::read> targetBox [[texture(2)]],
texture2d_array<half, access::write> output[[texture(3)]],
uint3 gid [[thread_position_in_grid]]) {
// All three inputs are read at the same texel (gid.xy) of slice gid.z.
half4 t = targetBox.read(gid.xy, gid.z);
half4 p = priorBox.read(gid.xy, gid.z);
half4 pv = priorBoxVar.read(gid.xy, gid.z);
// Prior centre and size, widened to float before the arithmetic.
float px = (float(p.x) + float(p.z)) / 2;
float py = (float(p.y) + float(p.w)) / 2;
float pw = float(p.z) - float(p.x);
float ph = float(p.w) - float(p.y);
// Decoded centre: variance-scaled target offset times the prior size,
// shifted by the prior centre.
float tx = float(pv.x) * float(t.x) * pw + px;
float ty = float(pv.y) * float(t.y) * ph + py;
// Decoded size: exp of the variance-scaled target value times the prior size.
float tw = exp(float(pv.z) * float(t.z)) * pw;
float th = exp(float(pv.w) * float(t.w)) * ph;
// Convert centre/size back into the two corner points.
float4 r;
r.x = tx - tw / 2;
r.y = ty - th / 2;
r.z = tx + tw / 2;
r.w = ty + th / 2;
// Narrow back to half for the half-typed output texture.
output.write(half4(r), gid.xy, gid.z);
// Instantiate the shared kernel body once per precision. BoxCoder.inc.metal
// reads the macro P as its element type, so P is defined immediately before
// each include and removed again afterwards.
// Float instantiation.
#define P float
#include "BoxCoder.inc.metal"
#undef P
// Half instantiation.
#define P half
#include "BoxCoder.inc.metal"
#undef P
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
想要评论请 注册