提交 4fdf6612 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #881 from codeWorm2015/metal

 fix conv add 1x5
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import Foundation import Foundation
let testTo = 54 let testTo = 61
var isTest = false var isTest = false
...@@ -133,7 +133,8 @@ public class Executor<P: PrecisionType> { ...@@ -133,7 +133,8 @@ public class Executor<P: PrecisionType> {
print(" 第 \(i) 个 op: ") print(" 第 \(i) 个 op: ")
op.delogOutput() op.delogOutput()
} }
// self.ops[53].delogOutput() // self.ops[59].delogOutput()
// self.ops[60].delogOutput()
return return
......
...@@ -17,6 +17,7 @@ import Foundation ...@@ -17,6 +17,7 @@ import Foundation
class ConvAddKernel<P: PrecisionType>: Kernel, Computable { class ConvAddKernel<P: PrecisionType>: Kernel, Computable {
var metalParam: MetalConvParam! var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvAddParam<P>) { required init(device: MTLDevice, param: ConvAddParam<P>) {
if computePrecision == .Float16 { if computePrecision == .Float16 {
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_add_1x1_half") super.init(device: device, inFunctionName: "conv_add_1x1_half")
...@@ -30,6 +31,8 @@ class ConvAddKernel<P: PrecisionType>: Kernel, Computable { ...@@ -30,6 +31,8 @@ class ConvAddKernel<P: PrecisionType>: Kernel, Computable {
super.init(device: device, inFunctionName: "conv_add_1x1") super.init(device: device, inFunctionName: "conv_add_1x1")
} else if param.filter.channel == 1 { } else if param.filter.channel == 1 {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3")
} else if param.filter.width == 1 && param.filter.height == 5 {
super.init(device: device, inFunctionName: "conv_add_5x1")
} else { } else {
super.init(device: device, inFunctionName: "conv_add_3x3") super.init(device: device, inFunctionName: "conv_add_3x3")
} }
...@@ -37,12 +40,12 @@ class ConvAddKernel<P: PrecisionType>: Kernel, Computable { ...@@ -37,12 +40,12 @@ class ConvAddKernel<P: PrecisionType>: Kernel, Computable {
fatalError() fatalError()
} }
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: computePrecision) let offsetY = (Int(param.dilations[1]) * (param.filter.height - 1) + 1)/2 - Int(param.paddings[1])
let offsetX = (Int(param.dilations[0]) * (param.filter.width - 1) + 1)/2 - Int(param.paddings[0]) let offsetX = (Int(param.dilations[0]) * (param.filter.width - 1) + 1)/2 - Int(param.paddings[0])
let offsetY = (Int(param.dilations[1]) * (param.filter.height - 1) + 1)/2 - Int(param.paddings[1]) param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: computePrecision)
param.filter.initBuffer(device: device, precision: computePrecision) param.filter.initBuffer(device: device, precision: computePrecision)
param.y.initBuffer(device: device, precision: computePrecision) param.y.initBuffer(device: device, precision: computePrecision)
......
...@@ -93,15 +93,6 @@ kernel void conv_add_3x3(texture2d_array<float, access::sample> inTexture [[text ...@@ -93,15 +93,6 @@ kernel void conv_add_3x3(texture2d_array<float, access::sample> inTexture [[text
float4 input[9]; float4 input[9];
for (uint i = 0; i < input_arr_size; ++i) { for (uint i = 0; i < input_arr_size; ++i) {
// input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i);
// input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
// input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i);
// input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i);
// input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
// input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i);
// input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i);
// input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i);
// input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i);
input[0] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y - dilation_y), i); input[0] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y - dilation_y), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - dilation_y), i); input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - dilation_y), i);
...@@ -138,8 +129,7 @@ kernel void conv_add_3x3(texture2d_array<float, access::sample> inTexture [[text ...@@ -138,8 +129,7 @@ kernel void conv_add_3x3(texture2d_array<float, access::sample> inTexture [[text
outTexture.write(output, gid.xy, gid.z); outTexture.write(output, gid.xy, gid.z);
} }
kernel void conv_add_5x1(texture2d_array<float, access::sample> inTexture [[texture(0)]],
kernel void test_conv_add_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]], texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]], constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]], const device float4 *weights [[buffer(1)]],
...@@ -152,14 +142,12 @@ kernel void test_conv_add_3x3(texture2d_array<float, access::sample> inTexture [ ...@@ -152,14 +142,12 @@ kernel void test_conv_add_3x3(texture2d_array<float, access::sample> inTexture [
return; return;
} }
if (gid.x > 0 || gid.y > 0 || gid.z > 0) { return; }
ushort2 stride = ushort2(param.strideX, param.strideY); ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY); const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero); constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9; const uint kernelHXW = 5;
uint input_arr_size = inTexture.get_array_size(); uint input_arr_size = inTexture.get_array_size();
...@@ -167,32 +155,21 @@ kernel void test_conv_add_3x3(texture2d_array<float, access::sample> inTexture [ ...@@ -167,32 +155,21 @@ kernel void test_conv_add_3x3(texture2d_array<float, access::sample> inTexture [
float4 output = float4(0.0); float4 output = float4(0.0);
ushort dilation_x = param.dilationX;
ushort dilation_y = param.dilationY; ushort dilation_y = param.dilationY;
float4 input[5];
float4 input[9];
for (uint i = 0; i < input_arr_size; ++i) { for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 2 * dilation_y), i);
input[0] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y - dilation_y), i); input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - dilation_y), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - dilation_y), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y - dilation_y), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y + dilation_y), i); input[2] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + dilation_y), i); input[3] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + dilation_y), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y + dilation_y), i); input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 2 * dilation_y), i);
for (int j = 0; j < 9; ++j) { for (int j = 0; j < 5; ++j) {
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i]; float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(input[j], weight_x); output.x += dot(input[j], weight_x);
...@@ -206,10 +183,11 @@ kernel void test_conv_add_3x3(texture2d_array<float, access::sample> inTexture [ ...@@ -206,10 +183,11 @@ kernel void test_conv_add_3x3(texture2d_array<float, access::sample> inTexture [
output.w += dot(input[j], weight_w); output.w += dot(input[j], weight_w);
} }
} }
// output = output + biase[gid.z];
outTexture.write(output, gid.xy, gid.z); outTexture.write(output, gid.xy, gid.z);
} }
kernel void depthwise_conv_add_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]], kernel void depthwise_conv_add_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]], texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]], constant MetalConvParam &param [[buffer(0)]],
...@@ -390,3 +368,78 @@ kernel void depthwise_conv_add_3x3_half(texture2d_array<half, access::sample> in ...@@ -390,3 +368,78 @@ kernel void depthwise_conv_add_3x3_half(texture2d_array<half, access::sample> in
output = output + float4(biase[gid.z]); output = output + float4(biase[gid.z]);
outTexture.write(half4(output), gid.xy, gid.z); outTexture.write(half4(output), gid.xy, gid.z);
} }
kernel void test_conv_add_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
const device float4 *biase [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
if (gid.x > 0 || gid.y > 0 || gid.z > 0) { return; }
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
ushort dilation_x = param.dilationX;
ushort dilation_y = param.dilationY;
float4 input[9];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y - dilation_y), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - dilation_y), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y - dilation_y), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y + dilation_y), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + dilation_y), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y + dilation_y), i);
for (int j = 0; j < 9; ++j) {
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(input[j], weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(input[j], weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(input[j], weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(input[j], weight_w);
}
}
// output = output + biase[gid.z];
outTexture.write(output, gid.xy, gid.z);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册