未验证 提交 e0ce2b3c 编写于 作者: Y Yanzhan Yang 提交者: GitHub

implement winograd for metal depthwise 3x3 naively (#1596)

* implement winograd for metal depthwise 3x3 naively

* keep half4 in non-winograd depthwise 3x3 to ensure precision
上级 cf70102b
......@@ -391,7 +391,6 @@ kernel void depthwise_conv_add_relu_3x3_half(texture2d_array<half, access::sampl
const device half *weights [[buffer(1)]],
const device half4 *biase [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
......@@ -421,10 +420,249 @@ kernel void depthwise_conv_add_relu_3x3_half(texture2d_array<half, access::sampl
output.z += float(input.z) * float(weights[weithTo + 2 * kernelHXW + j]);
output.w += float(input.w) * float(weights[weithTo + 3 * kernelHXW + j]);
}
float4 relu = fmax(output, 0.0);
outTexture.write(half4(relu), gid.xy, gid.z);
output = fmax(output, 0.0);
outTexture.write(half4(output), gid.xy, gid.z);
}
kernel void depthwise_conv_add_relu_3x3_half_winograd(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half *weights [[buffer(1)]],
const device half4 *biase [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
uint ow = outTexture.get_width();
uint oh = outTexture.get_height();
if (gid.x >= ow || gid.y >= oh) {
return;
}
uint tx = (gid.x / 2) * 2;
uint ty = (gid.y / 2) * 2;
uint tc = (gid.x % 2) * 2 + gid.y % 2;
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
half4 inputs[16];
inputs[0] = inTexture.sample(sample, float2(tx - 1, ty - 1), tc);
inputs[1] = inTexture.sample(sample, float2(tx, ty - 1), tc);
inputs[2] = inTexture.sample(sample, float2(tx + 1, ty - 1), tc);
inputs[3] = inTexture.sample(sample, float2(tx + 2, ty - 1), tc);
inputs[4] = inTexture.sample(sample, float2(tx - 1, ty), tc);
inputs[5] = inTexture.sample(sample, float2(tx, ty), tc);
inputs[6] = inTexture.sample(sample, float2(tx + 1, ty), tc);
inputs[7] = inTexture.sample(sample, float2(tx + 2, ty), tc);
inputs[8] = inTexture.sample(sample, float2(tx - 1, ty + 1), tc);
inputs[9] = inTexture.sample(sample, float2(tx, ty + 1), tc);
inputs[10] = inTexture.sample(sample, float2(tx + 1, ty + 1), tc);
inputs[11] = inTexture.sample(sample, float2(tx + 2, ty + 1), tc);
inputs[12] = inTexture.sample(sample, float2(tx - 1, ty + 2), tc);
inputs[13] = inTexture.sample(sample, float2(tx, ty + 2), tc);
inputs[14] = inTexture.sample(sample, float2(tx + 1, ty + 2), tc);
inputs[15] = inTexture.sample(sample, float2(tx + 2, ty + 2), tc);
half4 base = biase[tc];
half4 res[4] = {base, base, base, base};
half f[3][3];
const uint kernelHXW = 9;
uint weightTo = tc * kernelHXW * 4;
for (int c = 0; c < 4; ++c) {
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 3; ++j) {
f[i][j] = weights[weightTo++];
}
}
half I[16];
for (int i = 0; i < 16; ++i) {
I[i] = inputs[i][c];
}
half B[16];
half tmp1 = I[2] - I[10];
half tmp2 = I[1] - I[9];
B[0] = I[0] - I[8] - tmp1;
B[1] = tmp1 + tmp2;
B[2] = tmp1 - tmp2;
B[3] = I[3] - I[11] - tmp2;
tmp1 = I[6] + I[10];
tmp2 = I[5] + I[9];
B[4] = I[4] + I[8] - tmp1;
B[5] = tmp1 + tmp2;
B[6] = tmp1 - tmp2;
B[7] = I[7] + I[11] - tmp2;
tmp1 = I[6] - I[10];
tmp2 = I[5] - I[9];
B[8] = -I[4] + I[8] + tmp1;
B[9] = -tmp1 - tmp2;
B[10] = tmp2 - tmp1;
B[11] = tmp2 - I[7] + I[11];
tmp1 = I[6] - I[14];
tmp2 = I[5] - I[13];
B[12] = -I[4] + I[12] + tmp1;
B[13] = -tmp1 - tmp2;
B[14] = tmp2 - tmp1;
B[15] = tmp2 - I[7] + I[15];
half G[16];
G[0] = f[0][0];
G[1] = 0.5 * f[0][0] + 0.5 * f[0][1] + 0.5 * f[0][2];
G[2] = 0.5 * f[0][0] - 0.5 * f[0][1] + 0.5 * f[0][2];
G[3] = f[0][2];
G[4] = 0.5 * f[0][0] + 0.5 * f[1][0] + 0.5 * f[2][0];
G[5] = 0.25 * f[0][0] + 0.25 * f[0][1] + 0.25 * f[0][2] + 0.25 * f[1][0] + 0.25 * f[1][1] + 0.25 * f[1][2] + 0.25 * f[2][0] + 0.25 * f[2][1] + 0.25 * f[2][2];
G[6] = 0.25 * f[0][0] - 0.25 * f[0][1] + 0.25 * f[0][2] + 0.25 * f[1][0] - 0.25 * f[1][1] + 0.25 * f[1][2] + 0.25 * f[2][0] - 0.25 * f[2][1] + 0.25 * f[2][2];
G[7] = 0.5 * f[0][2] + 0.5 * f[1][2] + 0.5 * f[2][2];
G[8] = 0.5 * f[0][0] - 0.5 * f[1][0] + 0.5 * f[2][0];
G[9] = 0.25 * f[0][0] + 0.25 * f[0][1] + 0.25 * f[0][2] - 0.25 * f[1][0] - 0.25 * f[1][1] - 0.25 * f[1][2] + 0.25 * f[2][0] + 0.25 * f[2][1] + 0.25 * f[2][2];
G[10] = 0.25 * f[0][0] - 0.25 * f[0][1] + 0.25 * f[0][2] - 0.25 * f[1][0] + 0.25 * f[1][1] - 0.25 * f[1][2] + 0.25 * f[2][0] - 0.25 * f[2][1] + 0.25 * f[2][2];
G[11] = 0.5 * f[0][2] - 0.5 * f[1][2] + 0.5 * f[2][2];
G[12] = f[2][0];
G[13] = 0.5 * f[2][0] + 0.5 * f[2][1] + 0.5 * f[2][2];
G[14] = 0.5 * f[2][0] - 0.5 * f[2][1] + 0.5 * f[2][2];
G[15] = f[2][2];
half T[16];
for (int ii = 0; ii < 16; ++ii) {
T[ii] = B[ii] * G[ii];
}
tmp1 = T[1] + T[5] + T[9];
tmp2 = T[2] + T[6] + T[10];
res[0][c] += T[0] + T[4] + T[8] + tmp1 + tmp2;
res[1][c] += T[3] + T[7] + T[11] + tmp1 - tmp2;
tmp1 = T[5] - T[9] + T[13];
tmp2 = T[6] - T[10] + T[14];
res[2][c] += T[4] - T[8] + T[12] + tmp1 + tmp2;
res[3][c] += T[7] - T[11] + T[15] + tmp1 - tmp2;
}
outTexture.write(fmax(res[0], 0.0), uint2(tx, ty), tc);
outTexture.write(fmax(res[1], 0.0), uint2(tx + 1, ty), tc);
outTexture.write(fmax(res[2], 0.0), uint2(tx, ty + 1), tc);
outTexture.write(fmax(res[3], 0.0), uint2(tx + 1, ty + 1), tc);
}
//kernel void depthwise_conv_add_relu_3x3_half_winograd_naive(texture2d_array<half, access::sample> inTexture [[texture(0)]],
// texture2d_array<half, access::write> outTexture [[texture(1)]],
// constant MetalConvParam &param [[buffer(0)]],
// const device half *weights [[buffer(1)]],
// const device half4 *biase [[buffer(2)]],
// uint3 gid [[thread_position_in_grid]]) {
// uint ow = outTexture.get_width();
// uint oh = outTexture.get_height();
// if (gid.x >= ow || gid.y >= oh) {
// return;
// }
//
// uint tx = (gid.x / 2) * 2;
// uint ty = (gid.y / 2) * 2;
// uint tc = (gid.x % 2) * 2 + gid.y % 2;
//
// constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
// half4 inputs[4][4];
// inputs[0][0] = inTexture.sample(sample, float2(tx - 1, ty - 1), tc);
// inputs[0][1] = inTexture.sample(sample, float2(tx, ty - 1), tc);
// inputs[0][2] = inTexture.sample(sample, float2(tx + 1, ty - 1), tc);
// inputs[0][3] = inTexture.sample(sample, float2(tx + 2, ty - 1), tc);
//
// inputs[1][0] = inTexture.sample(sample, float2(tx - 1, ty), tc);
// inputs[1][1] = inTexture.sample(sample, float2(tx, ty), tc);
// inputs[1][2] = inTexture.sample(sample, float2(tx + 1, ty), tc);
// inputs[1][3] = inTexture.sample(sample, float2(tx + 2, ty), tc);
//
// inputs[2][0] = inTexture.sample(sample, float2(tx - 1, ty + 1), tc);
// inputs[2][1] = inTexture.sample(sample, float2(tx, ty + 1), tc);
// inputs[2][2] = inTexture.sample(sample, float2(tx + 1, ty + 1), tc);
// inputs[2][3] = inTexture.sample(sample, float2(tx + 2, ty + 1), tc);
//
// inputs[3][0] = inTexture.sample(sample, float2(tx - 1, ty + 2), tc);
// inputs[3][1] = inTexture.sample(sample, float2(tx, ty + 2), tc);
// inputs[3][2] = inTexture.sample(sample, float2(tx + 1, ty + 2), tc);
// inputs[3][3] = inTexture.sample(sample, float2(tx + 2, ty + 2), tc);
//
// const uint kernelHXW = 9;
// uint weightTo = tc * kernelHXW * 4;
//
// half f[3][3];
//
// half4 base = biase[tc];
// half4 res[2][2];
// res[0][0] = base;
// res[0][1] = base;
// res[1][0] = base;
// res[1][1] = base;
//
// for (int c = 0; c < 4; ++c) {
// for (int i = 0; i < 3; ++i) {
// for (int j = 0; j < 3; ++j) {
// f[i][j] = weights[weightTo++];
// }
// }
// half I[4][4];
// for (int ii = 0; ii < 4; ++ii) {
// for (int jj = 0; jj < 4; ++jj) {
// I[ii][jj] = inputs[ii][jj][c];
// }
// }
// half B[4][4];
// B[0][0] = I[0][0] - I[0][2] - I[2][0] + I[2][2];
// B[0][1] = I[0][1] + I[0][2] - I[2][1] - I[2][2];
// B[0][2] = -I[0][1] + I[0][2] + I[2][1] - I[2][2];
// B[0][3] = -I[0][1] + I[0][3] + I[2][1] - I[2][3];
// B[1][0] = I[1][0] - I[1][2] + I[2][0] - I[2][2];
// B[1][1] = I[1][1] + I[1][2] + I[2][1] + I[2][2];
// B[1][2] = -I[1][1] + I[1][2] - I[2][1] + I[2][2];
// B[1][3] = -I[1][1] + I[1][3] - I[2][1] + I[2][3];
// B[2][0] = -I[1][0] + I[1][2] + I[2][0] - I[2][2];
// B[2][1] = -I[1][1] - I[1][2] + I[2][1] + I[2][2];
// B[2][2] = I[1][1] - I[1][2] - I[2][1] + I[2][2];
// B[2][3] = I[1][1] - I[1][3] - I[2][1] + I[2][3];
// B[3][0] = -I[1][0] + I[1][2] + I[3][0] - I[3][2];
// B[3][1] = -I[1][1] - I[1][2] + I[3][1] + I[3][2];
// B[3][2] = I[1][1] - I[1][2] - I[3][1] + I[3][2];
// B[3][3] = I[1][1] - I[1][3] - I[3][1] + I[3][3];
// half G[4][4];
// G[0][0] = f[0][0];
// G[0][1] = 0.5 * f[0][0] + 0.5 * f[0][1] + 0.5 * f[0][2];
// G[0][2] = 0.5 * f[0][0] - 0.5 * f[0][1] + 0.5 * f[0][2];
// G[0][3] = f[0][2];
// G[1][0] = 0.5 * f[0][0] + 0.5 * f[1][0] + 0.5 * f[2][0];
// G[1][1] = 0.25 * f[0][0] + 0.25 * f[0][1] + 0.25 * f[0][2] + 0.25 * f[1][0] + 0.25 * f[1][1] + 0.25 * f[1][2] + 0.25 * f[2][0] + 0.25 * f[2][1] + 0.25 * f[2][2];
// G[1][2] = 0.25 * f[0][0] - 0.25 * f[0][1] + 0.25 * f[0][2] + 0.25 * f[1][0] - 0.25 * f[1][1] + 0.25 * f[1][2] + 0.25 * f[2][0] - 0.25 * f[2][1] + 0.25 * f[2][2];
// G[1][3] = 0.5 * f[0][2] + 0.5 * f[1][2] + 0.5 * f[2][2];
// G[2][0] = 0.5 * f[0][0] - 0.5 * f[1][0] + 0.5 * f[2][0];
// G[2][1] = 0.25 * f[0][0] + 0.25 * f[0][1] + 0.25 * f[0][2] - 0.25 * f[1][0] - 0.25 * f[1][1] - 0.25 * f[1][2] + 0.25 * f[2][0] + 0.25 * f[2][1] + 0.25 * f[2][2];
// G[2][2] = 0.25 * f[0][0] - 0.25 * f[0][1] + 0.25 * f[0][2] - 0.25 * f[1][0] + 0.25 * f[1][1] - 0.25 * f[1][2] + 0.25 * f[2][0] - 0.25 * f[2][1] + 0.25 * f[2][2];
// G[2][3] = 0.5 * f[0][2] - 0.5 * f[1][2] + 0.5 * f[2][2];
// G[3][0] = f[2][0];
// G[3][1] = 0.5 * f[2][0] + 0.5 * f[2][1] + 0.5 * f[2][2];
// G[3][2] = 0.5 * f[2][0] - 0.5 * f[2][1] + 0.5 * f[2][2];
// G[3][3] = f[2][2];
// half T[4][4];
// for (int ii = 0; ii < 4; ++ii) {
// for (int jj = 0; jj < 4; ++jj) {
// T[ii][jj] = B[ii][jj] * G[ii][jj];
// }
// }
// half A[2][2];
// A[0][0] = T[0][0] + T[0][1] + T[0][2] + T[1][0] + T[1][1] + T[1][2] + T[2][0] + T[2][1] + T[2][2];
// A[0][1] = T[0][1] - T[0][2] + T[0][3] + T[1][1] - T[1][2] + T[1][3] + T[2][1] - T[2][2] + T[2][3];
// A[1][0] = T[1][0] + T[1][1] + T[1][2] - T[2][0] - T[2][1] - T[2][2] + T[3][0] + T[3][1] + T[3][2];
// A[1][1] = T[1][1] - T[1][2] + T[1][3] - T[2][1] + T[2][2] - T[2][3] + T[3][1] - T[3][2] + T[3][3];
// for (int i = 0; i < 2; ++i) {
// for (int j = 0; j < 2; ++j) {
// res[i][j][c] += A[i][j];
// }
// }
// }
//
// for (int i = 0; i < 2; ++i) {
// for (int j = 0; j < 2; ++j) {
// half4 output = fmax(res[i][j], 0.0);
// outTexture.write(output, uint2(tx + j, ty + i), tc);
// }
// }
//}
kernel void conv_add_relu_5x1_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
......
......@@ -370,7 +370,7 @@ extension MTLDevice {
}
extension MTLComputeCommandEncoder {
public func dispatch(computePipline: MTLComputePipelineState, outTexture: MTLTexture) {
public func dispatch(computePipline: MTLComputePipelineState, outTexture: MTLTexture, groupDepth: Int? = nil) {
let slices = (outTexture.arrayLength * 4 + 3)/4
let width = computePipline.threadExecutionWidth
......@@ -382,8 +382,7 @@ extension MTLComputeCommandEncoder {
let groupWidth = (outTexture.width + width - 1)/width
let groupHeight = (outTexture.height + height - 1)/height
let groupDepth = slices
let groups = MTLSize.init(width: groupWidth, height: groupHeight, depth: groupDepth)
let groups = MTLSize.init(width: groupWidth, height: groupHeight, depth: groupDepth ?? slices)
setComputePipelineState(computePipline)
......
......@@ -110,15 +110,18 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
}
var shouldUseMPS = false
let functionName = type(of: self).kernelFunctionName(param: param)
if #available(iOS 11.0, *), initContext.useMPS {
shouldUseMPS = true
}
if type(of: self).isWinoGrad(functionName: functionName) {
shouldUseMPS = false
}
if shouldUseMPS {
super.init(device: device, inFunctionName: nil, initContext: initContext)
setupWithMPS(device: device, param: param)
} else {
let functionName = type(of: self).kernelFunctionName(param: param)
if functionName == nil {
fatalError(" unsupport yet ")
}
......@@ -136,7 +139,6 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
return
}
}
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
......@@ -145,7 +147,7 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
encoder.setBytes(&metalParam, length: MemoryLayout<MetalConvParam>.size, index: 0)
encoder.setBuffer(param.filter.buffer, offset: 0, index: 1)
encoder.setBuffer(param.y.buffer, offset: 0, index: 2)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture, groupDepth: type(of: self).isWinoGrad(functionName: functionName) ? 1 : nil)
encoder.endEncoding()
}
......@@ -194,6 +196,7 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
let padWhenOneC = !(param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1])
param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision, padWhenOneC: padWhenOneC)
param.y.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
}
......@@ -234,5 +237,12 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
func neuronFilterForMPSLayer(device: MTLDevice) -> AnyObject? {
return nil
}
open class func isWinoGrad(functionName: String?) -> Bool {
if let functionName = functionName {
return functionName.hasSuffix("winograd")
}
return false
}
}
......@@ -15,7 +15,11 @@ class ConvAddReluKernel<P: PrecisionProtocol>: ConvAddKernel<P> {
if param.filter.width == 1 && param.filter.height == 1 {
return "conv_add_relu_1x1_half"
} else if param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1] {
return "depthwise_conv_add_relu_3x3_half"
if param.filter.n == 16 && param.stride[0] == 1 && param.stride[1] == 1 && param.input.tensorDim[2] % 2 == 0 && param.input.tensorDim[3] % 2 == 0 && false {
return "depthwise_conv_add_relu_3x3_half_winograd"
} else {
return "depthwise_conv_add_relu_3x3_half"
}
} else if param.filter.width == 3 && param.filter.height == 3 {
return "conv_add_relu_3x3_half"
} else if param.filter.width == 1 && param.filter.height == 5 {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册