: Kernel, Computable {
var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvAddBatchNormReluParam) {
- super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_3x3")
- let offsetX = param.filter.dim[2]/2 - Int(param.paddings[0])
- let offsetY = param.filter.dim[1]/2 - Int(param.paddings[1])
+ if param.filter.width == 1 && param.filter.height == 1 {
+ super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_1x1")
+ } else if param.filter.channel == 1 {
+ super.init(device: device, inFunctionName: "depthwise_conv_add_batch_norm_relu_1x1")
+ } else {
+ super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_3x3")
+ }
+
+
+ let offsetX = param.filter.width/2 - Int(param.paddings[0])
+ let offsetY = param.filter.height/2 - Int(param.paddings[1])
+
+ print("offset x: \(offsetX)")
+ print("offset y: \(offsetY)")
+
let offsetZ = 0.0
metalParam = MetalConvParam.init(offsetX: Int16(offsetX), offsetY: Int16(offsetY), offsetZ: Int16(offsetZ), strideX: UInt16(param.stride[0]), strideY: UInt16(param.stride[1]), paddedZ: UInt16(param.input.metalTexture.arrayLength * 4 - param.input.dim[3]))
@@ -69,6 +81,4 @@ class ConvAddBatchNormReluKernel: Kernel, Computable {
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
-
-
}
diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddKernel.swift
index 9ce39e91fd366ffa7bceb7c265a10c5c12bad60b..950abd47f3f98c3f1404c25bd0a572043086df5e 100644
--- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddKernel.swift
+++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddKernel.swift
@@ -16,7 +16,7 @@ import Foundation
class ConvAddKernel: Kernel, Computable {
required init(device: MTLDevice, param: ConvAddParam) {
- super.init(device: device, inFunctionName: "conv3x3")
+ super.init(device: device, inFunctionName: "conv_add_1x1")
}
diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.metal
index 9cb8400dc734311de03381f2b4b641c128551f2f..7286e11d9a618ef6943b5d8462dc3a3e07072e1f 100644
--- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.metal
+++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.metal
@@ -24,41 +24,6 @@ struct MetalConvParam {
};
-kernel void conv3x3(texture2d_array inTexture [[texture(0)]],
- texture2d_array outTexture [[texture(1)]],
- constant MetalConvParam ¶m [[buffer(0)]],
- const device half4 *weights [[buffer(1)]],
- uint3 gid [[thread_position_in_grid]]) {
- if (gid.x >= outTexture.get_width() ||
- gid.y >= outTexture.get_height() ||
- gid.z >= outTexture.get_array_size()) {
- return;
- }
-
- short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY);
- constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
- const uint wightSliceCount = 36;
- uint weithTo = gid.z * wightSliceCount * inTexture.get_array_size();
- half4 output = 0.0;
- for (uint i = 0; i < inTexture.get_array_size(); ++i) {
- half4 input[9];
- input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i);
- input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
- input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i);
- input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i);
- input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
- input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i);
- input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i);
- input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i);
- input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i);
- for (int j = 0; j < 9; ++j) {
- half4 weight = weights[weithTo + wightSliceCount * i + j * 4];
- output += dot(input[j], weight);
- }
- }
- outTexture.write(output, gid.xy, gid.z);
-}
-
//kernel void conv_add_batch_norm_relu_3x3(texture2d_array inTexture [[texture(0)]],
// texture2d_array outTexture [[texture(1)]],
// constant MetalConvParam &param [[buffer(0)]],
@@ -119,30 +84,172 @@ kernel void conv_add_batch_norm_relu_3x3(texture2d_array
short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
- const uint wightSliceCount = 36;
- uint weithTo = gid.z * wightSliceCount * inTexture.get_array_size();
- float4 output = 0.0;
- for (uint i = 0; i < inTexture.get_array_size(); ++i) {
- float4 input[9];
- input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i);
- input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
- input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i);
- input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i);
- input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
- input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i);
- input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i);
- input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i);
- input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i);
+ const uint kernelHXW = 9;
+
+ uint input_arr_size = inTexture.get_array_size();
+ uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
+
+ float4 output = float4(0.0);
+
+ float4 input[9];
+ for (uint i = 0; i < input_arr_size; ++i) {
+ input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i);
+ input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
+ input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i);
+ input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i);
+ input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
+ input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i);
+ input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i);
+ input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i);
+ input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i);
for (int j = 0; j < 9; ++j) {
- float4 weight = weights[weithTo + wightSliceCount * i + j * 4];
- output += dot(input[j], weight);
+ float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
+ output.x += dot(input[j], weight_x);
+
+ float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
+ output.y += dot(input[j], weight_y);
+
+ float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
+ output.z += dot(input[j], weight_z);
+
+ float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
+ output.w += dot(input[j], weight_w);
}
}
+ output = fmax((output + biase[gid.z]) * new_scale[gid.z] + new_biase[gid.z], 0.0);
+ outTexture.write(output, gid.xy, gid.z);
+}
+
+
+
+// 1x1 convolution fused with bias add, scale/shift and ReLU.
+// Each thread computes one float4 texel of output slice gid.z:
+//   max((sum + biase[gid.z]) * new_scale[gid.z] + new_biase[gid.z], 0).
+// Weights hold one float4 per (output channel, input slice) pair; the four output
+// channels of a slice are laid out input_arr_size entries apart.
+kernel void conv_add_batch_norm_relu_1x1(texture2d_array<float, access::sample> inTexture [[texture(0)]],
+                                         texture2d_array<float, access::write> outTexture [[texture(1)]],
+                                         constant MetalConvParam &param [[buffer(0)]],
+                                         const device float4 *weights [[buffer(1)]],
+                                         const device float4 *biase [[buffer(2)]],
+                                         const device float4 *new_scale [[buffer(3)]],
+                                         const device float4 *new_biase [[buffer(4)]],
+                                         uint3 gid [[thread_position_in_grid]]) {
+    // Threads dispatched beyond the output extent have nothing to write.
+    if (gid.x >= outTexture.get_width() ||
+        gid.y >= outTexture.get_height() ||
+        gid.z >= outTexture.get_array_size()) {
+        return;
+    }
+
+    // Pixel coordinates, nearest filtering; address::clamp_to_zero governs reads outside the texture.
+    constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
+    const short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY);
+    const float2 samplePos = float2(posInInput.x, posInInput.y);
+
+    const uint input_arr_size = inTexture.get_array_size();
+    // Start of this slice's weights: four output channels, input_arr_size entries each.
+    const uint weithTo = gid.z * input_arr_size * 4;
+
+    float4 output = float4(0.0);
+    for (uint i = 0; i < input_arr_size; ++i) {
+        const float4 input = inTexture.sample(sample, samplePos, i);
+        // Component c of the texel accumulates output channel c of this slice.
+        for (uint c = 0; c < 4; ++c) {
+            output[c] += dot(input, weights[weithTo + c * input_arr_size + i]);
+        }
+    }
+
output = fmax((output + biase[gid.z]) * new_scale[gid.z] + new_biase[gid.z], 0.0);
outTexture.write(output, gid.xy, gid.z);
+}
+
+
+// 1x1 convolution followed by a bias add (no activation).
+// Each thread computes one float4 texel of output slice gid.z as sum + biase[gid.z].
+// new_scale and new_biase are declared at buffers 3 and 4 but are never read here.
+// Weights hold one float4 per (output channel, input slice) pair; the four output
+// channels of a slice are laid out input_arr_size entries apart.
+kernel void conv_add_1x1(texture2d_array<float, access::sample> inTexture [[texture(0)]],
+                         texture2d_array<float, access::write> outTexture [[texture(1)]],
+                         constant MetalConvParam &param [[buffer(0)]],
+                         const device float4 *weights [[buffer(1)]],
+                         const device float4 *biase [[buffer(2)]],
+                         const device float4 *new_scale [[buffer(3)]],
+                         const device float4 *new_biase [[buffer(4)]],
+                         uint3 gid [[thread_position_in_grid]]) {
+    // Threads dispatched beyond the output extent have nothing to write.
+    if (gid.x >= outTexture.get_width() ||
+        gid.y >= outTexture.get_height() ||
+        gid.z >= outTexture.get_array_size()) {
+        return;
+    }
+
+    // Pixel coordinates, nearest filtering; address::clamp_to_zero governs reads outside the texture.
+    constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
+    const short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY);
+    const float2 samplePos = float2(posInInput.x, posInInput.y);
+
+    const uint input_arr_size = inTexture.get_array_size();
+    // Start of this slice's weights: four output channels, input_arr_size entries each.
+    const uint weithTo = gid.z * input_arr_size * 4;
+
+    float4 output = float4(0.0);
+    for (uint i = 0; i < input_arr_size; ++i) {
+        const float4 input = inTexture.sample(sample, samplePos, i);
+        // Component c of the texel accumulates output channel c of this slice.
+        for (uint c = 0; c < 4; ++c) {
+            output[c] += dot(input, weights[weithTo + c * input_arr_size + i]);
+        }
+    }
+
+    output = output + biase[gid.z];
+    outTexture.write(output, gid.xy, gid.z);
}
+// Depthwise 3x3 convolution fused with bias add, scale/shift and ReLU.
+// NOTE: despite the `_1x1` suffix this kernel samples a 3x3 neighbourhood (nine taps).
+// Input slice gid.z is the only slice read for output slice gid.z, and each texel
+// component is multiplied by the matching weight component, so channels are not mixed.
+// Result: max((sum + biase[gid.z]) * new_scale[gid.z] + new_biase[gid.z], 0).
+kernel void depthwise_conv_add_batch_norm_relu_1x1(texture2d_array<float, access::sample> inTexture [[texture(0)]],
+                                                   texture2d_array<float, access::write> outTexture [[texture(1)]],
+                                                   constant MetalConvParam &param [[buffer(0)]],
+                                                   const device float4 *weights [[buffer(1)]],
+                                                   const device float4 *biase [[buffer(2)]],
+                                                   const device float4 *new_scale [[buffer(3)]],
+                                                   const device float4 *new_biase [[buffer(4)]],
+                                                   uint3 gid [[thread_position_in_grid]]) {
+    // Threads dispatched beyond the output extent have nothing to write.
+    if (gid.x >= outTexture.get_width() ||
+        gid.y >= outTexture.get_height() ||
+        gid.z >= outTexture.get_array_size()) {
+        return;
+    }
+
+    // Pixel coordinates, nearest filtering; address::clamp_to_zero governs reads outside the texture.
+    constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
+    const short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY);
+    const uint kernelHXW = 9;  // float4 weights (taps) per slice
+    const uint weithTo = gid.z * kernelHXW;
+
+    float4 output = float4(0.0);
+    // Taps are visited row by row, from (x-1, y-1) to (x+1, y+1); the weight for
+    // tap number `tap` is weights[weithTo + tap].
+    uint tap = 0;
+    for (int dy = -1; dy <= 1; ++dy) {
+        for (int dx = -1; dx <= 1; ++dx) {
+            const float2 tapPos = float2(posInInput.x + dx, posInInput.y + dy);
+            const float4 input = inTexture.sample(sample, tapPos, gid.z);
+            output += input * weights[weithTo + tap];
+            ++tap;
+        }
+    }
+
+    output = fmax((output + biase[gid.z]) * new_scale[gid.z] + new_biase[gid.z], 0.0);
+    outTexture.write(output, gid.xy, gid.z);
+}
diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.swift
index f6215c16154a376bb0527de888ad5dd8b24f3555..65a62121a64db8ce38318cad50f957e5b2bdb91e 100644
--- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.swift
+++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.swift
@@ -27,7 +27,7 @@ struct MetalConvParam {
class ConvKernel: Kernel, Computable {
var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvParam) {
- super.init(device: device, inFunctionName: "conv3x3")
+ super.init(device: device, inFunctionName: "conv_add_1x1")
let offsetX = param.filter.dim[2]/2 - Int(param.paddings[0])
let offsetY = param.filter.dim[1]/2 - Int(param.paddings[1])
let offsetZ = 0.0
diff --git a/metal/paddle-mobile/paddle-mobile/framework/Tensor.swift b/metal/paddle-mobile/paddle-mobile/framework/Tensor.swift
index dee5d79aa9bc16f99a1e231b55f0b5d8d30fa0b6..92bfe88c4994791e1d11646cf8796b8a8461f176 100644
--- a/metal/paddle-mobile/paddle-mobile/framework/Tensor.swift
+++ b/metal/paddle-mobile/paddle-mobile/framework/Tensor.swift
@@ -98,6 +98,8 @@ class Tensor: Tensorial {
buffer = device.makeBuffer(length: count * MemoryLayout.stride)
if C == paddedC {
buffer?.contents().copyMemory(from: data.pointer, byteCount: count * MemoryLayout
.stride)
+ } else if C == 1 {
+ buffer?.contents().copyMemory(from: data.pointer, byteCount: count * MemoryLayout
.stride)
} else {
var tmpPointer = data.pointer
var dstPtr = buffer?.contents().bindMemory(to: P.self, capacity: count)
@@ -121,6 +123,37 @@ class Tensor: Tensorial {
data.release()
}
+    /// The entry at index 1 of `dim`, exposed as the tensor width.
+    /// Traps through `fatalError()` unless `dim.cout()` is exactly 4.
+    /// NOTE(review): assumes index 1 is the width axis — confirm against the stored layout.
+    var width: Int {
+        guard dim.cout() == 4 else {
+            fatalError()
+        }
+        return dim[1]
+    }
+
+    /// The entry at index 2 of `dim`, exposed as the tensor height.
+    /// Traps through `fatalError()` unless `dim.cout()` is exactly 4.
+    /// NOTE(review): assumes index 2 is the height axis — confirm against the stored layout.
+    var height: Int {
+        guard dim.cout() == 4 else {
+            fatalError()
+        }
+        return dim[2]
+    }
+
+    /// The entry at index 3 of `dim`, exposed as the tensor channel count.
+    /// Traps through `fatalError()` unless `dim.cout()` is exactly 4.
+    /// NOTE(review): assumes index 3 is the channel axis — confirm against the stored layout.
+    var channel: Int {
+        guard dim.cout() == 4 else {
+            fatalError()
+        }
+        return dim[3]
+    }
+
+
func NCHW2NHWC(newPtr: UnsafeMutablePointer) {
let N = dim[0]
let C = dim[1]