提交 876c1291 编写于 作者: D dolphin8

concat

上级 e71320da
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
4AA1EA9E2148D6F900D0F791 /* ConcatKernel.inc.metal in Headers */ = {isa = PBXBuildFile; fileRef = 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */; }; 4AA1EA9E2148D6F900D0F791 /* ConcatKernel.inc.metal in Headers */ = {isa = PBXBuildFile; fileRef = 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */; };
4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */; }; 4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */; };
4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */; }; 4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */; };
4AA1EAA4214A295C00D0F791 /* Split.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA3214A295C00D0F791 /* Split.inc.metal */; };
4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; }; 4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; };
4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; }; 4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; };
4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; }; 4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; };
...@@ -130,6 +131,7 @@ ...@@ -130,6 +131,7 @@
4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ConcatKernel.inc.metal; sourceTree = "<group>"; }; 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ConcatKernel.inc.metal; sourceTree = "<group>"; };
4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ReshapeKernel.inc.metal; sourceTree = "<group>"; }; 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ReshapeKernel.inc.metal; sourceTree = "<group>"; };
4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenKernel.swift; sourceTree = "<group>"; }; 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenKernel.swift; sourceTree = "<group>"; };
4AA1EAA3214A295C00D0F791 /* Split.inc.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Split.inc.metal; sourceTree = "<group>"; };
4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; }; 4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; };
4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; }; 4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; };
4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = "<group>"; }; 4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = "<group>"; };
...@@ -451,6 +453,7 @@ ...@@ -451,6 +453,7 @@
FC4CB74820F0B954007C0C6D /* ConvKernel.metal */, FC4CB74820F0B954007C0C6D /* ConvKernel.metal */,
4AF928762133F1DB005B6C3A /* BoxCoder.metal */, 4AF928762133F1DB005B6C3A /* BoxCoder.metal */,
4AA1EA8F214664CD00D0F791 /* Split.metal */, 4AA1EA8F214664CD00D0F791 /* Split.metal */,
4AA1EAA3214A295C00D0F791 /* Split.inc.metal */,
4AA1EA892146631C00D0F791 /* BilinearInterp.metal */, 4AA1EA892146631C00D0F791 /* BilinearInterp.metal */,
4AF9287821341661005B6C3A /* Softmax.metal */, 4AF9287821341661005B6C3A /* Softmax.metal */,
FCEB6849212F00DB00D2448E /* PreluKernel.metal */, FCEB6849212F00DB00D2448E /* PreluKernel.metal */,
...@@ -584,6 +587,7 @@ ...@@ -584,6 +587,7 @@
FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */, FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */,
FCDE8A33212A917900F4A8F6 /* ConvTransposeOp.swift in Sources */, FCDE8A33212A917900F4A8F6 /* ConvTransposeOp.swift in Sources */,
FCBCCC6B2123071700D94F7E /* BoxcoderOp.swift in Sources */, FCBCCC6B2123071700D94F7E /* BoxcoderOp.swift in Sources */,
4AA1EAA4214A295C00D0F791 /* Split.inc.metal in Sources */,
FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */, FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */,
4AF9288421357BE3005B6C3A /* Elementwise.metal in Sources */, 4AF9288421357BE3005B6C3A /* Elementwise.metal in Sources */,
FCD04E7020F31B720007374F /* ReshapeKernel.swift in Sources */, FCD04E7020F31B720007374F /* ReshapeKernel.swift in Sources */,
......
...@@ -83,38 +83,38 @@ public class PaddleMobileUnitTest { ...@@ -83,38 +83,38 @@ public class PaddleMobileUnitTest {
} }
public func testConcat() { public func testConcat() {
let buffer = queue.makeCommandBuffer() ?! "buffer is nil" // let buffer = queue.makeCommandBuffer() ?! "buffer is nil"
var it: [[Float32]] = [] // var it: [[Float32]] = []
for _ in 0..<7 { // for _ in 0..<7 {
it.append((0..<12).map { Float32($0) }) // it.append((0..<12).map { Float32($0) })
} // }
let input = it.map { device.tensor2texture(value: $0, dim: [3, 4]) } // let input = it.map { device.tensor2texture(value: $0, dim: [3, 4]) }
let output = device.tensor2texture(value: [Float32](), dim: [3, 28]) // let output = device.tensor2texture(value: [Float32](), dim: [3, 28])
//
let param = ConcatTestParam.init( // let param = ConcatTestParam.init(
input: input, // input: input,
output: output, // output: output,
dims: [[3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4]], // dims: [[3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4]],
axis: 1, // axis: 1,
odim: [3, 28] // odim: [3, 28]
) // )
let concatKernel = ConcatKernel<Float32>.init(device: device, testParam: param) // let concatKernel = ConcatKernel<Float32>.init(device: device, testParam: param)
concatKernel.test(cmdBuffer: buffer, param: param) // concatKernel.test(cmdBuffer: buffer, param: param)
buffer.addCompletedHandler { (buffer) in // buffer.addCompletedHandler { (buffer) in
for i in 0..<it.count { // for i in 0..<it.count {
let _: Float32? = input[i].logDesc() // let _: Float32? = input[i].logDesc()
self.tensorPrint(tensor: it[i], dim: [3, 4]) // self.tensorPrint(tensor: it[i], dim: [3, 4])
} // }
let _: Float32? = output.logDesc() // let _: Float32? = output.logDesc()
let tx: [Float32] = self.device.texture2tensor(texture: output, dim: [3, 28]) // let tx: [Float32] = self.device.texture2tensor(texture: output, dim: [3, 28])
self.tensorPrint(tensor: tx, dim: [3, 28]) // self.tensorPrint(tensor: tx, dim: [3, 28])
} // }
//
buffer.commit() // buffer.commit()
} }
public func testReshape() { public func testReshape() {
let buffer = queue.makeCommandBuffer() ?! "buffer is nil" // let buffer = queue.makeCommandBuffer() ?! "buffer is nil"
// let input: [Float32] = (0..<24).map { Float32($0) } // let input: [Float32] = (0..<24).map { Float32($0) }
// let inTexture = device.tensor2texture(value: input, dim: [2, 3, 4]) // let inTexture = device.tensor2texture(value: input, dim: [2, 3, 4])
// let outTexture = device.tensor2texture(value: [Float32](), dim: [4, 6]) // let outTexture = device.tensor2texture(value: [Float32](), dim: [4, 6])
...@@ -139,32 +139,32 @@ public class PaddleMobileUnitTest { ...@@ -139,32 +139,32 @@ public class PaddleMobileUnitTest {
// self.tensorPrint(tensor: tx, dim: [4, 6]) // self.tensorPrint(tensor: tx, dim: [4, 6])
// } // }
let input: [Float32] = (0..<24).map { Float32($0) } // let input: [Float32] = (0..<24).map { Float32($0) }
let inTexture = device.tensor2texture(value: input, dim: [2, 3, 4]) // let inTexture = device.tensor2texture(value: input, dim: [2, 3, 4])
let outTexture = device.tensor2texture(value: [Float32](), dim: [24]) // let outTexture = device.tensor2texture(value: [Float32](), dim: [24])
let mp = ReshapeMetalParam.init( // let mp = ReshapeMetalParam.init(
idim: (1, 2, 3, 4), // idim: (1, 2, 3, 4),
itrans: (0, 1, 2, 3), // itrans: (0, 1, 2, 3),
odim: (1, 1, 1, 24), // odim: (1, 1, 1, 24),
otrans: (0, 1, 2, 3) // otrans: (0, 1, 2, 3)
) // )
let param = ReshapeTestParam.init( // let param = ReshapeTestParam.init(
inputTexture: inTexture, // inputTexture: inTexture,
outputTexture: outTexture, // outputTexture: outTexture,
param: mp // param: mp
) // )
let reshapeKernel = ReshapeKernel<Float32>.init(device: device, testParam: param) // let reshapeKernel = ReshapeKernel<Float32>.init(device: device, testParam: param)
reshapeKernel.test(commandBuffer: buffer, testParam: param) // reshapeKernel.test(commandBuffer: buffer, testParam: param)
buffer.addCompletedHandler { (buffer) in // buffer.addCompletedHandler { (buffer) in
let _: Float32? = inTexture.logDesc() // let _: Float32? = inTexture.logDesc()
let _: Float32? = outTexture.logDesc() // let _: Float32? = outTexture.logDesc()
self.tensorPrint(tensor: input, dim: [2, 3, 4]) // self.tensorPrint(tensor: input, dim: [2, 3, 4])
let tx: [Float32] = self.device.texture2tensor(texture: outTexture, dim: [24]) // let tx: [Float32] = self.device.texture2tensor(texture: outTexture, dim: [24])
self.tensorPrint(tensor: tx, dim: [24]) // self.tensorPrint(tensor: tx, dim: [24])
} // }
//
//
buffer.commit() // buffer.commit()
} }
public func testTranspose() { public func testTranspose() {
......
...@@ -30,7 +30,7 @@ public class MobileNet_ssd_AR: Net{ ...@@ -30,7 +30,7 @@ public class MobileNet_ssd_AR: Net{
class MobilenetssdPreProccess: CusomKernel { class MobilenetssdPreProccess: CusomKernel {
init(device: MTLDevice) { init(device: MTLDevice) {
let s = CusomKernel.Shape.init(inWidth: 160, inHeight: 160, inChannel: 3) let s = CusomKernel.Shape.init(inWidth: 160, inHeight: 160, inChannel: 3)
super.init(device: device, inFunctionName: "mobilent_ar_preprocess_half", outputDim: s, usePaddleMobileLib: false) super.init(device: device, inFunctionName: "mobilent_ar_preprocess", outputDim: s, usePaddleMobileLib: false)
} }
} }
......
...@@ -19,15 +19,15 @@ class BilinearInterpParam<P: PrecisionType>: OpParam { ...@@ -19,15 +19,15 @@ class BilinearInterpParam<P: PrecisionType>: OpParam {
required init(opDesc: OpDesc, inScope: Scope) throws { required init(opDesc: OpDesc, inScope: Scope) throws {
do { do {
input = try BilinearInterpParam.inputX(inputs: opDesc.inputs, from: inScope) input = try BilinearInterpParam.inputX(inputs: opDesc.inputs, from: inScope)
// if (input.transpose != [0, 2, 3, 1]) || (input.tensorDim.cout() != 4) {
// fatalError()
// }
output = try BilinearInterpParam.outputOut(outputs: opDesc.outputs, from: inScope) output = try BilinearInterpParam.outputOut(outputs: opDesc.outputs, from: inScope)
out_h = try BilinearInterpParam.getAttr(key: "out_h", attrs: opDesc.attrs) out_h = try BilinearInterpParam.getAttr(key: "out_h", attrs: opDesc.attrs)
out_w = try BilinearInterpParam.getAttr(key: "out_w", attrs: opDesc.attrs) out_w = try BilinearInterpParam.getAttr(key: "out_w", attrs: opDesc.attrs)
} catch let error { } catch let error {
throw error throw error
} }
if (input.transpose != [0, 2, 3, 1]) || (input.tensorDim.cout() != 4) {
fatalError()
}
} }
let input: Texture<P> let input: Texture<P>
var output: Texture<P> var output: Texture<P>
...@@ -53,6 +53,15 @@ class BilinearInterpOp<P: PrecisionType>: Operator<BilinearInterpKernel<P>, Bili ...@@ -53,6 +53,15 @@ class BilinearInterpOp<P: PrecisionType>: Operator<BilinearInterpKernel<P>, Bili
func delogOutput() { func delogOutput() {
print(" \(type) output: ") print(" \(type) output: ")
let padToFourDim = para.output.padToFourDim
if para.output.transpose == [0, 1, 2, 3] {
let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3]))
print(outputArray.strideArray())
} else if para.output.transpose == [0, 2, 3, 1] {
print(para.output.metalTexture.toTensor(dim: (n: padToFourDim[0], c: padToFourDim[1], h: padToFourDim[2], w: padToFourDim[3])).strideArray())
} else {
fatalError(" not implemet")
}
} }
} }
......
...@@ -31,102 +31,111 @@ struct ConcatMetalParam { ...@@ -31,102 +31,111 @@ struct ConcatMetalParam {
} }
class ConcatKernel<P: PrecisionType>: Kernel, Computable{ class ConcatKernel<P: PrecisionType>: Kernel, Computable{
var v = "normal"
func encodeTest(_ cmdBuffer: MTLCommandBuffer, _ param: ConcatTestParam, _ istart: Int, _ iend: Int) { var pm = ConcatMetalParam.init()
let encoder = cmdBuffer.makeComputeCommandEncoder()! func compute(commandBuffer: MTLCommandBuffer, param: ConcatParam<P>) throws {
var p = ConcatMetalParam.init()
var odim: [Int32] = [1, 1, 1, 1] guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
for i in 0..<param.odim.count { throw PaddleMobileError.predictError(message: " encode is nil")
odim[4-param.odim.count+i] = Int32(param.odim[i])
}
p.odim = (odim[0], odim[1], odim[2], odim[3])
p.axis = Int32(4 - param.odim.count + param.axis)
for i in 0..<istart {
p.offset += Int32(param.dims[i][param.axis])
} }
var vdim: [Int32] = [] let num = param.input.count
for i in 0..<(iend - istart) { for i in 0..<num {
encoder.setTexture(param.input[i+istart], index: i) encoder.setTexture(param.input[i].metalTexture, index: i)
vdim.append(Int32(param.dims[i+istart][Int(param.axis)]))
} }
for i in (iend-istart)..<6 { encoder.setTexture(param.output.metalTexture, index: num)
encoder.setTexture(param.input[0], index: i) if v == "normal" {
vdim.append(0) encoder.setTexture(param.output.metalTexture, index: num + 1)
} }
p.vdim = (vdim[0], vdim[1], vdim[2], vdim[3], vdim[4], vdim[5]) encoder.setBytes(&pm, length: MemoryLayout<ConcatMetalParam>.size, index: 0)
encoder.setTexture(param.output, index: 6) encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.setTexture(param.output, index: 7)
encoder.setBytes(&p, length: MemoryLayout<ConcatMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output)
encoder.endEncoding() encoder.endEncoding()
} }
func encode(_ cmdBuffer: MTLCommandBuffer, _ param: ConcatParam<P>, _ istart: Int, _ iend: Int) throws { required init(device: MTLDevice, param: ConcatParam<P>) {
guard let encoder = cmdBuffer.makeComputeCommandEncoder() else { param.output.initTexture(device: device, inTranspose: param.transpose, computePrecision: computePrecision)
throw PaddleMobileError.predictError(message: " encode is nil") let orank = param.output.tensorDim.cout()
} let num = param.input.count
var p = ConcatMetalParam.init() assert(num <= 6)
let odim = (0..<4).map { Int32(param.output.dim[$0]) } var axis = 4 - param.output.tensorDim.cout() + param.axis
p.odim = (odim[0], odim[1], odim[2], odim[3])
p.axis = Int32(4 - param.output.tensorDim.cout() + param.axis)
for i in 0..<4 { for i in 0..<4 {
if Int32(param.transpose[i]) == p.axis { if param.transpose[i] == axis {
p.axis = Int32(i) axis = i
break break
} }
} }
for i in 0..<istart { pm.axis = Int32(axis)
p.offset += Int32(param.input[i+istart].dim[Int(p.axis)]) pm.odim = (Int32(param.output.dim[0]), Int32(param.output.dim[1]), Int32(param.output.dim[2]), Int32(param.output.dim[3]))
} pm.trans = (Int32(param.output.transpose[0]), Int32(param.output.transpose[1]), Int32(param.output.transpose[2]), Int32(param.output.transpose[3]))
var vdim: [Int32] = [] var vdim: [Int] = [0, 0, 0, 0, 0, 0]
for i in 0..<(iend - istart) { for i in 0..<num {
encoder.setTexture(param.input[i+istart].metalTexture, index: i) vdim[i] = param.input[i].dim[axis]
vdim.append(Int32(param.input[i+istart].dim[Int(p.axis)]))
}
for i in (iend-istart)..<6 {
encoder.setTexture(param.input[0].metalTexture, index: i)
vdim.append(0)
}
p.trans = (Int32(param.transpose[0]), Int32(param.transpose[1]), Int32(param.transpose[2]), Int32(param.transpose[3]))
p.vdim = (vdim[0], vdim[1], vdim[2], vdim[3], vdim[4], vdim[5])
encoder.setTexture(param.output.metalTexture, index: 6)
encoder.setTexture(param.output.metalTexture, index: 7)
encoder.setBytes(&p, length: MemoryLayout<ConcatMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
func compute(commandBuffer: MTLCommandBuffer, param: ConcatParam<P>) throws {
let group = param.input.count / 6
let remain = param.input.count % 6
for i in 0..<group {
try self.encode(commandBuffer, param, 6 * i, 6 * (i + 1))
}
if remain > 0 {
try self.encode(commandBuffer, param, 6 * group, param.input.count)
}
}
func test(cmdBuffer: MTLCommandBuffer, param: ConcatTestParam) {
let group = param.input.count / 6
let remain = param.input.count % 6
for i in 0..<group {
self.encodeTest(cmdBuffer, param, 6 * i, 6 * (i + 1))
} }
if remain > 0 { if orank == 4 {
self.encodeTest(cmdBuffer, param, 6 * group, param.input.count) if axis == 1 {
v = "y"
} else if axis == 2 {
v = "x"
} else {
if (param.output.dim[0] == 1) && axis == 3 {
var vz = true
for i in 0..<num {
if vdim[i] % 4 != 0 {
vz = false
break
}
}
if vz {
v = "z"
for i in 0..<num {
vdim[i] = vdim[i] / 4
}
}
}
}
} else if orank == 3 {
if axis == 2 {
v = "y"
} else if axis == 3 {
v = "x"
} else if axis == 1 {
var vz = true
for i in 0..<num {
if vdim[i] % 4 != 0 {
vz = false
break
}
}
if vz {
v = "z"
for i in 0..<num {
vdim[i] = vdim[i] / 4
}
}
}
} else {
if axis == 2 {
v = "y"
} else if axis == 3 {
var vx = true
for i in 0..<num {
if vdim[i] % 4 != 0 {
vx = false
break
}
}
if vx {
v = "x"
for i in 0..<num {
vdim[i] = vdim[i] / 4
}
}
}
} }
} pm.vdim = (Int32(vdim[0]), Int32(vdim[1]), Int32(vdim[2]), Int32(vdim[3]), Int32(vdim[4]), Int32(vdim[5]))
required init(device: MTLDevice, param: ConcatParam<P>) {
param.output.initTexture(device: device, inTranspose: param.transpose, computePrecision: computePrecision)
let orank = param.output.tensorDim.cout()
if computePrecision == .Float32 { if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "concat_\(orank)_float") super.init(device: device, inFunctionName: "concat_\(orank)_\(num)_\(v)_float")
} else if computePrecision == .Float16 { } else if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "concat_\(orank)_half") super.init(device: device, inFunctionName: "concat_\(orank)_\(num)_\(v)_half")
} else { } else {
fatalError() fatalError()
} }
......
...@@ -71,10 +71,11 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -71,10 +71,11 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{
} }
func compute(commandBuffer: MTLCommandBuffer, param: ReshapeParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: ReshapeParam<P>) throws {
print("reshape compute")
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encoder is nil") throw PaddleMobileError.predictError(message: " encoder is nil")
} }
encoder.setTexture(param.input.metalTexture, index: 0) encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1) encoder.setTexture(param.output.metalTexture, index: 1)
...@@ -83,15 +84,15 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -83,15 +84,15 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{
encoder.endEncoding() encoder.endEncoding()
} }
func test(commandBuffer: MTLCommandBuffer, testParam: ReshapeTestParam) { // func test(commandBuffer: MTLCommandBuffer, testParam: ReshapeTestParam) {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { // guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
fatalError() // fatalError()
} // }
encoder.setTexture(testParam.inputTexture, index: 0) // encoder.setTexture(testParam.inputTexture, index: 0)
encoder.setTexture(testParam.outputTexture, index: 1) // encoder.setTexture(testParam.outputTexture, index: 1)
var pm: ReshapeMetalParam = testParam.param // var pm: ReshapeMetalParam = testParam.param
encoder.setBytes(&pm, length: MemoryLayout<ReshapeMetalParam>.size, index: 0) // encoder.setBytes(&pm, length: MemoryLayout<ReshapeMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: testParam.outputTexture) // encoder.dispatch(computePipline: pipline, outTexture: testParam.outputTexture)
encoder.endEncoding() // encoder.endEncoding()
} // }
} }
...@@ -19,11 +19,12 @@ struct ShapeMetalParam { ...@@ -19,11 +19,12 @@ struct ShapeMetalParam {
class ShapeKernel<P: PrecisionType>: Kernel, Computable{ class ShapeKernel<P: PrecisionType>: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: ShapeParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: ShapeParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { print("shape compute")
throw PaddleMobileError.predictError(message: " encode is nil") // guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
} // throw PaddleMobileError.predictError(message: " encode is nil")
encoder.setTexture(param.output.metalTexture, index: 0) // }
encoder.endEncoding() // encoder.setTexture(param.output.metalTexture, index: 0)
// encoder.endEncoding()
} }
required init(device: MTLDevice, param: ShapeParam<P>) { required init(device: MTLDevice, param: ShapeParam<P>) {
......
...@@ -15,26 +15,76 @@ ...@@ -15,26 +15,76 @@
import Foundation import Foundation
struct SplitMetalParam { struct SplitMetalParam {
var idim: (Int32, Int32, Int32, Int32) = (1, 1, 1, 1)
var axis: Int32 = 0
var offset: Int32 = 0
var trans: (Int32, Int32, Int32, Int32) = (0, 1, 2, 3)
var vdim: (Int32, Int32, Int32, Int32) = (0, 0, 0, 0)
} }
class SplitKernel<P: PrecisionType>: Kernel, Computable{ class SplitKernel<P: PrecisionType>: Kernel, Computable{
var smp: SplitMetalParam
func compute(commandBuffer: MTLCommandBuffer, param: SplitParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: SplitParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil") throw PaddleMobileError.predictError(message: " encode is nil")
} }
encoder.setTexture(param.output.metalTexture, index: 0) encoder.setTexture(param.input.metalTexture, index: 0)
for i in 0..<param.outputList.count {
encoder.setTexture(param.outputList[i].metalTexture, index: i + 1)
}
encoder.setBytes(&smp, length: MemoryLayout<BoxcoderMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.input.metalTexture)
encoder.endEncoding() encoder.endEncoding()
} }
required init(device: MTLDevice, param: SplitParam<P>) { required init(device: MTLDevice, param: SplitParam<P>) {
// param.output.initTexture(device: device, computePrecision: computePrecision) // param.output.initTexture(device: device, computePrecision: computePrecision)
let num = param.outputList.count
let rank = param.input.tensorDim.cout()
assert(num >= 2 && num <= 4)
for output in param.outputList { for output in param.outputList {
output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision) output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision)
} }
smp = SplitMetalParam.init()
smp.idim = (Int32(param.input.dim[0]), Int32(param.input.dim[1]), Int32(param.input.dim[2]), Int32(param.input.dim[3]))
smp.axis = Int32(param.axis + param.input.dim.cout() - param.input.tensorDim.cout())
for i in 0..<4 {
if param.input.transpose[i] == smp.axis {
smp.axis = Int32(i)
break
}
}
smp.trans = (Int32(param.input.transpose[0]), Int32(param.input.transpose[1]), Int32(param.input.transpose[2]), Int32(param.input.transpose[3]))
var vdim: [Int32] = [0, 0, 0, 0]
for i in 0..<num {
vdim[i] = Int32(param.outputList[i].tensorDim[param.axis])
}
smp.vdim = (vdim[0], vdim[1], vdim[2], vdim[3])
var v = "normal"
if rank == 4 {
if smp.axis == 1 {
v = "y"
} else if smp.axis == 2 {
v = "x"
}
} else if rank == 3 {
if smp.axis == 2 {
v = "y"
} else if smp.axis == 3 {
v = "x"
}
} else if rank == 2 {
if smp.axis == 2 {
v = "y"
}
}
if v == "normal" {
fatalError("split unsupported")
}
if computePrecision == .Float32 { if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "split") super.init(device: device, inFunctionName: "split_\(rank)_\(num)_\(v)")
} else if computePrecision == .Float16 { } else if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "split_half") super.init(device: device, inFunctionName: "split_\(rank)_\(num)_\(v)_half")
} else { } else {
fatalError() fatalError()
} }
......
...@@ -23,7 +23,7 @@ struct bilinear_interp_param { ...@@ -23,7 +23,7 @@ struct bilinear_interp_param {
}; };
kernel void bilinear_interp(texture2d_array<float, access::read> input [[texture(0)]], kernel void bilinear_interp(texture2d_array<float, access::read> input [[texture(0)]],
texture2d_array<float, access::write> output [[texture(2)]], texture2d_array<float, access::write> output [[texture(1)]],
constant bilinear_interp_param & pm [[buffer(0)]], constant bilinear_interp_param & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) { uint3 gid [[thread_position_in_grid]]) {
float4 r; float4 r;
...@@ -47,29 +47,29 @@ kernel void bilinear_interp(texture2d_array<float, access::read> input [[texture ...@@ -47,29 +47,29 @@ kernel void bilinear_interp(texture2d_array<float, access::read> input [[texture
output.write(r, gid.xy, gid.z); output.write(r, gid.xy, gid.z);
} }
kernel void bilinear_interp_half(texture2d_array<half, access::read> input [[texture(0)]], //kernel void bilinear_interp_half(texture2d_array<half, access::read> input [[texture(0)]],
texture2d_array<half, access::write> output [[texture(2)]], // texture2d_array<half, access::write> output [[texture(1)]],
constant bilinear_interp_param & pm [[buffer(0)]], // constant bilinear_interp_param & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) { // uint3 gid [[thread_position_in_grid]]) {
//
half4 r; // half4 r;
if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) { // if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) {
r = input.read(gid.xy, gid.z); // r = input.read(gid.xy, gid.z);
} else { // } else {
half w = gid.x * pm.ratio_w; // half w = gid.x * pm.ratio_w;
half h = gid.y * pm.ratio_h; // half h = gid.y * pm.ratio_h;
uint w0 = w, h0 = h; // uint w0 = w, h0 = h;
uint w1 = w0 + 1, h1 = h0 + 1; // uint w1 = w0 + 1, h1 = h0 + 1;
half w1lambda = w - w0, h1lambda = h - h0; // half w1lambda = w - w0, h1lambda = h - h0;
half w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda; // half w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda;
if (w1 >= input.get_width()) w1 = w0; // if (w1 >= input.get_width()) w1 = w0;
if (h1 >= input.get_height()) h1 = h0; // if (h1 >= input.get_height()) h1 = h0;
half4 r0 = input.read(uint2(w0, h0), gid.z); // half4 r0 = input.read(uint2(w0, h0), gid.z);
half4 r1 = input.read(uint2(w1, h0), gid.z); // half4 r1 = input.read(uint2(w1, h0), gid.z);
half4 r2 = input.read(uint2(w0, h1), gid.z); // half4 r2 = input.read(uint2(w0, h1), gid.z);
half4 r3 = input.read(uint2(w1, h1), gid.z); // half4 r3 = input.read(uint2(w1, h1), gid.z);
r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3); // r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3);
} // }
output.write(r, gid.xy, gid.z); // output.write(r, gid.xy, gid.z);
output.write(r, gid.xy, gid.z); // output.write(r, gid.xy, gid.z);
} //}
...@@ -3,24 +3,52 @@ ...@@ -3,24 +3,52 @@
#define CONCAT2(a, b) a ## b #define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b #define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c #define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d
#define CONCAT5_(a, b, c, d, e) a ## _ ## b ## _ ## c ## _ ## d ## _ ## e
#define FUNC(f, r, p) CONCAT3_(f, r, p) #define FUNC(f, r, n, v, p) CONCAT5_(f, r, n, v, p)
#define VECTOR(p, n) CONCAT2(p, n) #define VECTOR(p, n) CONCAT2(p, n)
#define FUNC_R(f, r) CONCAT2_(f, r) #define FUNC_R(f, r) CONCAT2_(f, r)
kernel void FUNC(concat, R, P)(texture2d_array<P, access::read> in0 [[texture(0)]], #if V == VX
texture2d_array<P, access::read> in1 [[texture(1)]], #define VV x
texture2d_array<P, access::read> in2 [[texture(2)]], #elif V == VY
texture2d_array<P, access::read> in3 [[texture(3)]], #define VV y
texture2d_array<P, access::read> in4 [[texture(4)]], #elif V == VZ
texture2d_array<P, access::read> in5 [[texture(5)]], #define VV z
texture2d_array<P, access::read> inx [[texture(6)]], #else
texture2d_array<P, access::write> out [[texture(7)]], #define VV normal
constant ConcatParam & pm [[buffer(0)]], #endif
uint3 gid [[thread_position_in_grid]]) {
#if V == VNORMAL
//kernel void FUNC(concat, R, N, normal, P)(array<texture2d_array<P, access::read>, N> in [[texture(0)]],
// texture2d_array<P, access::read> out_x [[texture(N)]],
// texture2d_array<P, access::write> out [[texture(N+1)]],
// constant ConcatParam & pm [[buffer(0)]],
// uint3 gid [[thread_position_in_grid]]) {
//}
kernel void FUNC(concat, R, N, VV, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
texture2d_array<P, access::read> in1 [[texture(1)]],
#if N >= 3
texture2d_array<P, access::read> in2 [[texture(2)]],
#endif
#if N >= 4
texture2d_array<P, access::read> in3 [[texture(3)]],
#endif
#if N >= 5
texture2d_array<P, access::read> in4 [[texture(4)]],
#endif
#if N >= 6
texture2d_array<P, access::read> in5 [[texture(5)]],
#endif
texture2d_array<P, access::read> inx [[texture(N)]],
texture2d_array<P, access::write> out [[texture(N+1)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
ConcatParam cp = pm; ConcatParam cp = pm;
int xyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, abcd[4], oxyzn[4]; int xyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, abcd[4], oxyzn[4];
VECTOR(P, 4) r; VECTOR(P, 4) r = inx.read(gid.xy, gid.z);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
xyzn[3] = i; xyzn[3] = i;
#if R == 4 #if R == 4
...@@ -29,35 +57,248 @@ kernel void FUNC(concat, R, P)(texture2d_array<P, access::read> in0 [[texture(0) ...@@ -29,35 +57,248 @@ kernel void FUNC(concat, R, P)(texture2d_array<P, access::read> in0 [[texture(0)
FUNC_R(xyzn2abcd, R)(xyzn, abcd); FUNC_R(xyzn2abcd, R)(xyzn, abcd);
#endif #endif
int k = abcd[cp.axis] - cp.offset; int k = abcd[cp.axis] - cp.offset;
if (k < 0) continue;
int j = 0; int j = 0;
if (k < 0) { for (; j < N; j++) {
r[i] = inx.read(gid.xy, gid.z)[i]; if (k < cp.vdim[j]) {
} else { break;
for (; j < 6; j++) {
if (k < cp.vdim[j]) {
break;
}
k -= cp.vdim[j];
} }
int ta = cp.odim[cp.axis]; k -= cp.vdim[j];
abcd[cp.axis] = k; }
cp.odim[cp.axis] = cp.vdim[j]; if (k > cp.vdim[N-1]) {
continue;
}
int ta = cp.odim[cp.axis];
abcd[cp.axis] = k;
cp.odim[cp.axis] = cp.vdim[j];
#if R == 4 #if R == 4
abcd2xyzn_4(cp.odim[3], abcd, oxyzn); abcd2xyzn_4(cp.odim[3], abcd, oxyzn);
#else #else
FUNC_R(abcd2xyzn, R)(abcd, oxyzn); FUNC_R(abcd2xyzn, R)(abcd, oxyzn);
#endif #endif
cp.odim[cp.axis] = ta; cp.odim[cp.axis] = ta;
switch (j) { switch (j) {
case 0: r[i] = in0.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; case 0: r[i] = in0.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 1: r[i] = in1.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; case 1: r[i] = in1.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 2: r[i] = in2.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; #if N >= 3
case 3: r[i] = in3.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; case 2: r[i] = in2.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 4: r[i] = in4.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; #endif
case 5: r[i] = in5.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break; #if N >= 4
} case 3: r[i] = in3.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
#endif
#if N >= 5
case 4: r[i] = in4.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
#endif
#if N >= 6
case 5: r[i] = in5.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
#endif
} }
} }
out.write(r, gid.xy, gid.z); out.write(r, gid.xy, gid.z);
} }
#endif
#endif // V == NORMAL
#if V == VX
kernel void FUNC(concat, R, N, VV, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
texture2d_array<P, access::read> in1 [[texture(1)]],
#if N >= 3
texture2d_array<P, access::read> in2 [[texture(2)]],
#endif // N >= 3
#if N >= 4
texture2d_array<P, access::read> in3 [[texture(3)]],
#endif // N >= 4
#if N >= 5
texture2d_array<P, access::read> in4 [[texture(4)]],
#endif // N >= 5
#if N >= 6
texture2d_array<P, access::read> in5 [[texture(5)]],
#endif // N >= 6
texture2d_array<P, access::write> out [[texture(N)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
int x = gid.x - pm.offset;
if (x < 0) return;
if (x < pm.vdim[0]) {
VECTOR(P, 4) r = in0.read(gid.xy, gid.z);
out.write(r, gid.xy, gid.z);
return;
}
x -= pm.vdim[0];
if (x < pm.vdim[1]) {
VECTOR(P, 4) r = in1.read(uint2(x, gid.y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#if N >= 3
x -= pm.vdim[1];
if (x < pm.vdim[2]) {
VECTOR(P, 4) r = in2.read(uint2(x, gid.y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 3
#if N >= 4
x -= pm.vdim[2];
if (x < pm.vdim[3]) {
VECTOR(P, 4) r = in3.read(uint2(x, gid.y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 4
#if N >= 5
x -= pm.vdim[3];
if (x < pm.vdim[4]) {
VECTOR(P, 4) r = in4.read(uint2(x, gid.y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 5
#if N >= 6
x -= pm.vdim[4];
if (x < pm.vdim[5]) {
VECTOR(P, 4) r = in5.read(uint2(x, gid.y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 6
}
#endif // V == VX
#if V == VY
kernel void FUNC(concat, R, N, VV, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
texture2d_array<P, access::read> in1 [[texture(1)]],
#if N >= 3
texture2d_array<P, access::read> in2 [[texture(2)]],
#endif // N >= 3
#if N >= 4
texture2d_array<P, access::read> in3 [[texture(3)]],
#endif // N >= 4
#if N >= 5
texture2d_array<P, access::read> in4 [[texture(4)]],
#endif // N >= 5
#if N >= 6
texture2d_array<P, access::read> in5 [[texture(5)]],
#endif // N >= 6
texture2d_array<P, access::write> out [[texture(N)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
int y = gid.y - pm.offset;
if (y < 0) return;
if (y < pm.vdim[0]) {
VECTOR(P, 4) r = in0.read(gid.xy, gid.z);
out.write(r, gid.xy, gid.z);
return;
}
y -= pm.vdim[0];
if (y < pm.vdim[1]) {
VECTOR(P, 4) r = in1.read(uint2(gid.x, y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#if N >= 3
y -= pm.vdim[1];
if (y < pm.vdim[2]) {
VECTOR(P, 4) r = in2.read(uint2(gid.x, y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 3
#if N >= 4
y -= pm.vdim[2];
if (y < pm.vdim[3]) {
VECTOR(P, 4) r = in3.read(uint2(gid.x, y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 4
#if N >= 5
y -= pm.vdim[3];
if (y < pm.vdim[4]) {
VECTOR(P, 4) r = in4.read(uint2(gid.x, y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 5
#if N >= 6
y -= pm.vdim[4];
if (y < pm.vdim[5]) {
VECTOR(P, 4) r = in5.read(uint2(gid.x, y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 6
}
#endif // V == VY
#if V == VZ
kernel void FUNC(concat, R, N, VV, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
texture2d_array<P, access::read> in1 [[texture(1)]],
#if N >= 3
texture2d_array<P, access::read> in2 [[texture(2)]],
#endif // N >= 3
#if N >= 4
texture2d_array<P, access::read> in3 [[texture(3)]],
#endif // N >= 4
#if N >= 5
texture2d_array<P, access::read> in4 [[texture(4)]],
#endif // N >= 5
#if N >= 6
texture2d_array<P, access::read> in5 [[texture(5)]],
#endif // N >= 6
texture2d_array<P, access::write> out [[texture(N)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
int z = gid.z - pm.offset;
if (z < 0) return;
if (z < pm.vdim[0]) {
VECTOR(P, 4) r = in0.read(gid.xy, gid.z);
out.write(r, gid.xy, gid.z);
return;
}
z -= pm.vdim[0];
if (z < pm.vdim[1]) {
VECTOR(P, 4) r = in1.read(gid.xy, z);
out.write(r, gid.xy, gid.z);
return;
}
#if N >= 3
z -= pm.vdim[1];
if (z < pm.vdim[2]) {
VECTOR(P, 4) r = in2.read(gid.xy, z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 3
#if N >= 4
z -= pm.vdim[2];
if (z < pm.vdim[3]) {
VECTOR(P, 4) r = in3.read(gid.xy, z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 4
#if N >= 5
z -= pm.vdim[3];
if (z < pm.vdim[4]) {
VECTOR(P, 4) r = in4.read(gid.xy, z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 5
#if N >= 6
z -= pm.vdim[4];
if (z < pm.vdim[5]) {
VECTOR(P, 4) r = in5.read(gid.xy, z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 6
}
#endif // V == VZ
#undef VV
#endif // #ifdef P
...@@ -25,32 +25,116 @@ struct ConcatParam { ...@@ -25,32 +25,116 @@ struct ConcatParam {
int32_t vdim[6]; int32_t vdim[6];
}; };
#define P float #define VNORMAL 1
#define R 4 #define VX 2
#include "ConcatKernel.inc.metal" #define VY 3
#undef R #define VZ 4
#define R 3
#include "ConcatKernel.inc.metal" // >> fast mode
#undef R // only support concat_{2,3,4}_{2,3,4,5,6}_y_{float,half}
#define R 2 // only support concat_{3,4}_{2,3,4,5,6}_x_{float,half}
#include "ConcatKernel.inc.metal" // only support concat_{1,2,3,4}_{2,3,4,5,6}_z_{float,half}
#undef R // >> normal mode (loop mode)
#define R 1 // ssd-ar: (R=4, N=3, V=z), (R=3, N=2, V=y), (R=2, N=5, V=x), (R=3, N=5, V=x)
#include "ConcatKernel.inc.metal" // ssd: (R=2, N=6, V=y), (R=3, N=6, V=y)
#undef R // genet: (R=4, N=2, V=normal)
#undef P
// ssd-ar: (R=3, N=5, V=x)
#define P half #define V VX
#define R 4 #define R 3
#include "ConcatKernel.inc.metal" #define N 5
#undef R #define P float
#define R 3 #include "ConcatKernel.inc.metal"
#include "ConcatKernel.inc.metal" #undef P
#undef R #define P half
#define R 2 #include "ConcatKernel.inc.metal"
#include "ConcatKernel.inc.metal" #undef P
#undef R #undef N
#define R 1 #undef R
#include "ConcatKernel.inc.metal" #undef V
#undef R
#undef P // ssd-ar: (R=2, N=5, V=x)
#define V VX
#define R 2
#define N 5
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
// ssd-ar: (R=3, N=2, V=y)
#define V VY
#define R 3
#define N 2
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
// ssd-ar: (R=4, N=3, V=z)
#define V VZ
#define R 4
#define N 3
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
// ssd: (R=2, N=6, V=y)
#define V VY
#define R 2
#define N 6
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
// ssd: (R=3, N=6, V=y)
#define V VY
#define R 3
#define N 6
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
#define V VNORMAL
#define R 4
#define N 2
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
#ifdef P
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d
#define CONCAT5_(a, b, c, d, e) a ## _ ## b ## _ ## c ## _ ## d ## _ ## e
#define FUNC(f, r, n, v, p) CONCAT5_(f, r, n, v, p)
#define VECTOR(p, n) CONCAT2(p, n)
#define FUNC_R(f, r) CONCAT2_(f, r)
kernel void FUNC(split, R, N, V, P)(texture2d_array<P, access::read> input [[texture(0)]],
texture2d_array<P, access::write> out1 [[texture(1)]],
texture2d_array<P, access::write> out2 [[texture(2)]],
#if N >= 3
texture2d_array<P, access::write> out3 [[texture(3)]],
#endif
#if N >= 4
texture2d_array<P, access::write> out4 [[texture(4)]],
#endif
constant SplitParam &sp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
VECTOR(P, 4) r = input.read(gid.xy, gid.z);
#if V == y
int y = gid.y - sp.offset;
if (y < sp.vdim[0]) {
out1.write(r, gid.xy, gid.z);
} else {
y -= sp.vdim[0];
if (y < sp.vdim[1]) {
out2.write(r, uint2(gid.x, y), gid.z);
} else {
#if N >= 3
y -= sp.vdim[1];
if (y < sp.vdim[2]) {
out3.write(r, uint2(gid.x, y), gid.z);
} else {
#if N >= 4
y -= sp.vdim[2];
if (y < sp.vdim[3]) {
out4.write(r, uint2(gid.x, y), gid.z);
}
#endif
}
#endif
}
}
#elif V == x
int x = gid.x;
if (x < sp.vdim[0]) {
out1.write(r, gid.xy, gid.z);
} else {
x -= sp.vdim[0];
if (x < sp.vdim[1]) {
out2.write(r, uint2(x, gid.y), gid.z);
} else {
#if N >= 3
x -= sp.vdim[1];
if (x < sp.vdim[2]) {
out3.write(r, uint2(x, gid.y), gid.z);
} else {
#if N >= 4
x -= sp.vdim[2];
if (x < sp.vdim[3]) {
out4.write(r, uint2(x, gid.y), gid.z);
}
#endif
}
#endif
}
}
#else
#endif
}
#endif
...@@ -13,18 +13,60 @@ ...@@ -13,18 +13,60 @@
limitations under the License. */ limitations under the License. */
#include <metal_stdlib> #include <metal_stdlib>
#include "Common.metal"
using namespace metal; using namespace metal;
kernel void split(texture2d_array<float, access::write> output[[texture(0)]], struct SplitParam {
uint3 gid [[thread_position_in_grid]]) { int32_t idim[4];
float4 r; int32_t axis;
int32_t offset;
int32_t trans[4];
int32_t vdim[4];
};
// only support split_{2, 3, 4}_{2, 3, 4}_y_{float, half}
// only support split_{3, 4}_{2, 3, 4}_x_{float, half}
#define V y
// for R in 2..4
#define R 3
// for N in 2..4
#define N 2
#define P float
#include "Split.inc.metal"
#undef P
#define P half
#include "Split.inc.metal"
#undef P
#undef N
// end for N
#undef R
// end for R
#undef V
#define V x
// for R in 3..4
#define R 3
// for N in 2..4
#define N 2
#define P float
#include "Split.inc.metal"
#undef P
#define P half
#include "Split.inc.metal"
#undef P
#undef N
// end for N
output.write(r, gid.xy, gid.z); #undef R
} // end for R
#undef V
kernel void split_half(texture2d_array<half, access::write> output[[texture(0)]],
uint3 gid [[thread_position_in_grid]]) {
float4 r;
output.write(half4(r), gid.xy, gid.z);
}
...@@ -16,7 +16,7 @@ import Foundation ...@@ -16,7 +16,7 @@ import Foundation
class ScaleKernel: CusomKernel { class ScaleKernel: CusomKernel {
init(device: MTLDevice, shape: Shape) { init(device: MTLDevice, shape: Shape) {
super.init(device: device, inFunctionName: "scale_half", outputDim: shape, usePaddleMobileLib: false) super.init(device: device, inFunctionName: "scale", outputDim: shape, usePaddleMobileLib: false)
} }
} }
......
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
import Foundation import Foundation
let testTo = 3 let testTo = 114
var isTest = false var isTest = false
let computePrecision: ComputePrecision = .Float16 let computePrecision: ComputePrecision = .Float32
public class ResultHolder { public class ResultHolder {
public let dim: [Int] public let dim: [Int]
...@@ -101,7 +101,7 @@ public class Executor<P: PrecisionType> { ...@@ -101,7 +101,7 @@ public class Executor<P: PrecisionType> {
let inputTexture = InputTexture.init(inMTLTexture: resInput, inExpectDim: Dim.init(inDim: dim)) let inputTexture = InputTexture.init(inMTLTexture: resInput, inExpectDim: Dim.init(inDim: dim))
program.scope.setInput(input: inputTexture) program.scope.setInput(input: inputTexture)
//(ops.count - except) //(ops.count - except)
for i in 0..<ops.count { for i in 0..<testTo {
let op = ops[i] let op = ops[i]
do { do {
try op.run(device: device, buffer: buffer) try op.run(device: device, buffer: buffer)
...@@ -112,35 +112,35 @@ public class Executor<P: PrecisionType> { ...@@ -112,35 +112,35 @@ public class Executor<P: PrecisionType> {
var outputTextures: [String : [Variant]]? var outputTextures: [String : [Variant]]?
if except > 0 { if except > 0 {
outputTextures = ops[ops.count - except].inputVariant() outputTextures = ops[testTo-1].inputVariant()
} }
buffer.addCompletedHandler { [weak self] (commandbuffer) in buffer.addCompletedHandler { [weak self] (commandbuffer) in
// let inputArr = resInput.toTensor(dim: (n: dim[0], c: dim[3], h: dim[1], w: dim[2])) let inputArr = resInput.toTensor(dim: (n: dim[0], c: dim[3], h: dim[1], w: dim[2]))
//// print(inputArr.strideArray()) print(inputArr.strideArray())
// print(dim) // print(dim)
// writeToLibrary(fileName: "test_image_ssd_ar", array: inputArr) // writeToLibrary(fileName: "test_image_ssd_ar", array: inputArr)
//
// print("write to library done") // print("write to library done")
// return // return
// print(inputArr) // print(inputArr)
//
// let stridableInput: [(index: Int, value: Float)] = input.stridableFloatArray() // let stridableInput: [(index: Int, value: Float)] = input.stridableFloatArray()
// print(stridableInput) // print(stridableInput)
//
// let _: Flo? = input.logDesc(header: "input: ", stridable: true) // let _: Flo? = input.logDesc(header: "input: ", stridable: true)
// for i in 0..<self.ops.count { for i in 0..<testTo {
// let op = self.ops[i] let op = self!.ops[i]
// print(" 第 \(i) 个 op: ") print(" 第 \(i) 个 op: ")
// op.delogOutput() op.delogOutput()
// } }
// return; // return;
// self.ops[testTo - 2].delogOutput() // self!.ops[testTo - 2].delogOutput()
// self.ops[testTo - 1].delogOutput() // self!.ops[testTo - 1].delogOutput()
// self.ops[60].delogOutput() // self!.ops[60].delogOutput()
// return // return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册