提交 6e5e698d 编写于 作者: L liuruilong

fix crash

上级 dcef687f
...@@ -34,6 +34,6 @@ class BoxcoderKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -34,6 +34,6 @@ class BoxcoderKernel<P: PrecisionType>: Kernel, Computable{
required init(device: MTLDevice, param: BoxcoderParam<P>) { required init(device: MTLDevice, param: BoxcoderParam<P>) {
param.output.initTexture(device: device) param.output.initTexture(device: device)
super.init(device: device, inFunctionName: "priorbox") super.init(device: device, inFunctionName: "boxcoder")
} }
} }
...@@ -40,11 +40,11 @@ struct ConvBNReluTestParam: TestParam { ...@@ -40,11 +40,11 @@ struct ConvBNReluTestParam: TestParam {
class ConvBNReluKernel<P: PrecisionType>: Kernel, Computable, Testable { class ConvBNReluKernel<P: PrecisionType>: Kernel, Computable, Testable {
required init(device: MTLDevice, testParam: ConvBNReluTestParam) { required init(device: MTLDevice, testParam: ConvBNReluTestParam) {
if testParam.filterSize.width == 1 && testParam.filterSize.height == 1 { if testParam.filterSize.width == 1 && testParam.filterSize.height == 1 {
super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_1x1") super.init(device: device, inFunctionName: "conv_batch_norm_relu_1x1")
} else if testParam.filterSize.channel == 1 { } else if testParam.filterSize.channel == 1 {
super.init(device: device, inFunctionName: "depthwise_conv_add_batch_norm_relu_3x3") super.init(device: device, inFunctionName: "depthwise_conv_batch_norm_relu_3x3")
} else { } else {
super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_3x3") super.init(device: device, inFunctionName: "conv_batch_norm_relu_3x3")
} }
} }
...@@ -53,11 +53,11 @@ class ConvBNReluKernel<P: PrecisionType>: Kernel, Computable, Testable { ...@@ -53,11 +53,11 @@ class ConvBNReluKernel<P: PrecisionType>: Kernel, Computable, Testable {
required init(device: MTLDevice, param: ConvBNReluParam<P>) { required init(device: MTLDevice, param: ConvBNReluParam<P>) {
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_1x1") super.init(device: device, inFunctionName: "conv_batch_norm_relu_1x1")
} else if param.filter.channel == 1 { } else if param.filter.channel == 1 {
super.init(device: device, inFunctionName: "depthwise_conv_add_batch_norm_relu_3x3") super.init(device: device, inFunctionName: "depthwise_conv_batch_norm_relu_3x3")
} else { } else {
super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_3x3") super.init(device: device, inFunctionName: "conv_batch_norm_relu_3x3")
} }
param.output.initTexture(device: device, transpose: [0, 2, 3, 1]) param.output.initTexture(device: device, transpose: [0, 2, 3, 1])
param.filter.initBuffer(device: device, precision: Tensor.BufferPrecision.Float32) param.filter.initBuffer(device: device, precision: Tensor.BufferPrecision.Float32)
...@@ -74,6 +74,8 @@ class ConvBNReluKernel<P: PrecisionType>: Kernel, Computable, Testable { ...@@ -74,6 +74,8 @@ class ConvBNReluKernel<P: PrecisionType>: Kernel, Computable, Testable {
print("offset y: \(offsetY)") print("offset y: \(offsetY)")
let offsetZ = 0.0 let offsetZ = 0.0
print(" fuck ")
metalParam = MetalConvParam.init(offsetX: Int16(offsetX), offsetY: Int16(offsetY), offsetZ: Int16(offsetZ), strideX: UInt16(param.stride[0]), strideY: UInt16(param.stride[1]), paddedZ: UInt16(param.input.metalTexture.arrayLength * 4 - param.input.dim[3])) metalParam = MetalConvParam.init(offsetX: Int16(offsetX), offsetY: Int16(offsetY), offsetZ: Int16(offsetZ), strideX: UInt16(param.stride[0]), strideY: UInt16(param.stride[1]), paddedZ: UInt16(param.input.metalTexture.arrayLength * 4 - param.input.dim[3]))
var invs: [P] = [] var invs: [P] = []
......
...@@ -26,6 +26,6 @@ class MulticlassNMSKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -26,6 +26,6 @@ class MulticlassNMSKernel<P: PrecisionType>: Kernel, Computable{
} }
required init(device: MTLDevice, param: MulticlassNMSParam<P>) { required init(device: MTLDevice, param: MulticlassNMSParam<P>) {
super.init(device: device, inFunctionName: "priorbox") super.init(device: device, inFunctionName: "prior_box")
} }
} }
...@@ -33,7 +33,7 @@ class PriorBoxKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -33,7 +33,7 @@ class PriorBoxKernel<P: PrecisionType>: Kernel, Computable{
var metalParam: PriorBoxMetalParam! var metalParam: PriorBoxMetalParam!
required init(device: MTLDevice, param: PriorBoxParam<P>) { required init(device: MTLDevice, param: PriorBoxParam<P>) {
super.init(device: device, inFunctionName: "priorbox") super.init(device: device, inFunctionName: "prior_box")
param.output.initTexture(device: device, transpose: [2, 0, 1, 3]) param.output.initTexture(device: device, transpose: [2, 0, 1, 3])
param.outputVariances.initTexture(device: device, transpose: [2, 0, 1, 3]) param.outputVariances.initTexture(device: device, transpose: [2, 0, 1, 3])
......
...@@ -32,6 +32,7 @@ class Texture2DTo2DArrayKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -32,6 +32,7 @@ class Texture2DTo2DArrayKernel<P: PrecisionType>: Kernel, Computable{
} }
required init(device: MTLDevice, param: FeedParam<P>) { required init(device: MTLDevice, param: FeedParam<P>) {
param.output.initTexture(device: device, transpose: [0, 2, 3, 1])
super.init(device: device, inFunctionName: "texture2d_to_2d_array") super.init(device: device, inFunctionName: "texture2d_to_2d_array")
} }
} }
...@@ -699,3 +699,144 @@ kernel void depthwise_conv_add_3x3(texture2d_array<float, access::sample> inText ...@@ -699,3 +699,144 @@ kernel void depthwise_conv_add_3x3(texture2d_array<float, access::sample> inText
outTexture.write(output, gid.xy, gid.z); outTexture.write(output, gid.xy, gid.z);
} }
#pragma mark - conv bn relu
kernel void conv_batch_norm_relu_1x1(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
const device float4 *biase [[buffer(2)]],
const device float4 *new_scale [[buffer(3)]],
const device float4 *new_biase [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 1;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
float4 input;
for (uint i = 0; i < input_arr_size; ++i) {
input = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + i];
output.x += dot(input, weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + i];
output.y += dot(input, weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + i];
output.z += dot(input, weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + i];
output.w += dot(input, weight_w);
}
output = fmax(output * new_scale[gid.z] + new_biase[gid.z], 0.0);
outTexture.write(output, gid.xy, gid.z);
}
kernel void conv_batch_norm_relu_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
const device float4 *biase [[buffer(2)]],
const device float4 *new_scale [[buffer(3)]],
const device float4 *new_biase [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
float4 input[9];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i);
for (int j = 0; j < 9; ++j) {
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(input[j], weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(input[j], weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(input[j], weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(input[j], weight_w);
}
}
output = fmax(output * new_scale[gid.z] + new_biase[gid.z], 0.0);
outTexture.write(output, gid.xy, gid.z);
}
kernel void depthwise_conv_batch_norm_relu_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float *weights [[buffer(1)]],
const device float4 *new_scale [[buffer(3)]],
const device float4 *new_biase [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
uint output_slice = gid.z;
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint weithTo = gid.z * kernelHXW * 4;
float4 output = float4(0.0);
float4 inputs[9];
inputs[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), output_slice);
inputs[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), output_slice);
inputs[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), output_slice);
inputs[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), output_slice);
inputs[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), output_slice);
inputs[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), output_slice);
inputs[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), output_slice);
inputs[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), output_slice);
inputs[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), output_slice);
for (int j = 0; j < 9; ++j) {
float4 input = inputs[j];
output.x += input.x * weights[weithTo + 0 * kernelHXW + j];
output.y += input.y * weights[weithTo + 1 * kernelHXW + j];
output.z += input.z * weights[weithTo + 2 * kernelHXW + j];
output.w += input.w * weights[weithTo + 3 * kernelHXW + j];
}
output = fmax(output * new_scale[gid.z] + new_biase[gid.z], 0.0);
outTexture.write(output, gid.xy, gid.z);
}
...@@ -27,7 +27,7 @@ class PriorBoxParam<P: PrecisionType>: OpParam { ...@@ -27,7 +27,7 @@ class PriorBoxParam<P: PrecisionType>: OpParam {
aspectRatios = try PriorBoxParam.getAttr(key: "aspect_ratios", attrs: opDesc.attrs) aspectRatios = try PriorBoxParam.getAttr(key: "aspect_ratios", attrs: opDesc.attrs)
variances = try PriorBoxParam.getAttr(key: "variances", attrs: opDesc.attrs) variances = try PriorBoxParam.getAttr(key: "variances", attrs: opDesc.attrs)
flip = try PriorBoxParam.getAttr(key: "flip", attrs: opDesc.attrs) flip = try PriorBoxParam.getAttr(key: "flip", attrs: opDesc.attrs)
clip = try PriorBoxParam.getAttr(key: "clop", attrs: opDesc.attrs) clip = try PriorBoxParam.getAttr(key: "clip", attrs: opDesc.attrs)
stepW = try PriorBoxParam.getAttr(key: "step_w", attrs: opDesc.attrs) stepW = try PriorBoxParam.getAttr(key: "step_w", attrs: opDesc.attrs)
stepH = try PriorBoxParam.getAttr(key: "step_h", attrs: opDesc.attrs) stepH = try PriorBoxParam.getAttr(key: "step_h", attrs: opDesc.attrs)
offset = try PriorBoxParam.getAttr(key: "offset", attrs: opDesc.attrs) offset = try PriorBoxParam.getAttr(key: "offset", attrs: opDesc.attrs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册