未验证 提交 716326c5 编写于 作者: D dolphin8 提交者: GitHub

Merge pull request #820 from dolphin8/metal

concat
......@@ -20,7 +20,8 @@ class ViewController: UIViewController {
inDevice: device,
inQueue: queue
)
test.testReshape()
test.testConcat()
// test.testReshape()
// test.testTranspose()
print(" done ")
}
......
......@@ -82,6 +82,37 @@ public class PaddleMobileUnitTest {
indentPrintTensor(tensor: tensor, dim: ndim, ix: dim.map { $0 * 0 }, indentLevel: 0)
}
public func testConcat() {
let buffer = queue.makeCommandBuffer() ?! "buffer is nil"
var it: [[Float32]] = []
for _ in 0..<7 {
it.append((0..<12).map { Float32($0) })
}
let input = it.map { device.tensor2texture(value: $0, dim: [3, 4]) }
let output = device.tensor2texture(value: [Float32](), dim: [3, 28])
let param = ConcatTestParam.init(
input: input,
output: output,
dims: [[3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4]],
axis: 1,
odim: [3, 28]
)
let concatKernel = ConcatKernel<Float32>.init(device: device, testParam: param)
concatKernel.test(cmdBuffer: buffer, param: param)
buffer.addCompletedHandler { (buffer) in
for i in 0..<it.count {
let _: Float32? = input[i].logDesc()
self.tensorPrint(tensor: it[i], dim: [3, 4])
}
let _: Float32? = output.logDesc()
let tx: [Float32] = self.device.texture2tensor(texture: output, dim: [3, 28])
self.tensorPrint(tensor: tx, dim: [3, 28])
}
buffer.commit()
}
public func testReshape() {
let buffer = queue.makeCommandBuffer() ?! "buffer is nil"
// let input: [Float32] = (0..<24).map { Float32($0) }
......
......@@ -41,8 +41,8 @@ class ConcatParam<P: PrecisionType>: OpParam {
class ConcatOp<P: PrecisionType>: Operator<ConcatKernel<P>, ConcatParam<P>>, Runable, Creator, InferShaperable{
func inferShape() {
let dim = para.input.reduce([0, 0]) {[$0[0] + $1.dim[0], $1.dim[1]]}
para.output.dim = Dim.init(inDim: dim)
// let dim = para.input.reduce([0, 0]) {[$0[0] + $1.dim[0], $1.dim[1]]}
// para.output.dim = Dim.init(inDim: dim)
}
typealias OpType = ConcatOp<P>
......
......@@ -14,18 +14,114 @@
import Foundation
struct ConcatTestParam: TestParam {
var input: [MTLTexture]
var output: MTLTexture
var dims: [[Int]]
var axis: Int
var odim: [Int]
}
struct ConcatMetalParam {
var odim: (Int32, Int32, Int32, Int32) = (1, 1, 1, 1)
var axis: Int32 = 0
var offset: Int32 = 0
var vdim: (Int32, Int32, Int32, Int32, Int32, Int32) = (0, 0, 0, 0, 0, 0)
}
class ConcatKernel<P: PrecisionType>: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: ConcatParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
func encodeTest(_ cmdBuffer: MTLCommandBuffer, _ param: ConcatTestParam, _ istart: Int, _ iend: Int) {
let encoder = cmdBuffer.makeComputeCommandEncoder()!
var p = ConcatMetalParam.init()
var odim: [Int32] = [1, 1, 1, 1]
for i in 0..<param.odim.count {
odim[4-param.odim.count+i] = Int32(param.odim[i])
}
p.odim = (odim[0], odim[1], odim[2], odim[3])
p.axis = Int32(4 - param.odim.count + param.axis)
for i in 0..<istart {
p.offset += Int32(param.dims[i][param.axis])
}
var vdim: [Int32] = []
for i in 0..<(iend - istart) {
encoder.setTexture(param.input[i+istart], index: i)
vdim.append(Int32(param.dims[i+istart][Int(param.axis)]))
}
for i in (iend-istart)..<6 {
encoder.setTexture(param.input[0], index: i)
vdim.append(0)
}
p.vdim = (vdim[0], vdim[1], vdim[2], vdim[3], vdim[4], vdim[5])
encoder.setTexture(param.output, index: 6)
encoder.setTexture(param.output, index: 7)
encoder.setBytes(&p, length: MemoryLayout<ConcatMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output)
encoder.endEncoding()
}
func encode(_ cmdBuffer: MTLCommandBuffer, _ param: ConcatParam<P>, _ istart: Int, _ iend: Int) throws {
guard let encoder = cmdBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
// encoder.setTexture(param.input.metalTexture, index: 0)
// encoder.setTexture(param.output.metalTexture, index: 1)
var p = ConcatMetalParam.init()
let odim = (0..<4).map { Int32(param.output.dim[$0]) }
p.odim = (odim[0], odim[1], odim[2], odim[3])
p.axis = Int32(4 - param.output.tensorDim.cout() + param.axis)
for i in 0..<istart {
p.offset += Int32(param.input[i+istart].dim[Int(p.axis)])
}
var vdim: [Int32] = []
for i in 0..<(iend - istart) {
encoder.setTexture(param.input[i+istart].metalTexture, index: i)
vdim.append(Int32(param.input[i+istart].dim[Int(p.axis)]))
}
for i in (iend-istart)..<6 {
encoder.setTexture(param.input[0].metalTexture, index: i)
vdim.append(0)
}
p.vdim = (vdim[0], vdim[1], vdim[2], vdim[3], vdim[4], vdim[5])
encoder.setTexture(param.output.metalTexture, index: 6)
encoder.setTexture(param.output.metalTexture, index: 7)
encoder.setBytes(&p, length: MemoryLayout<ConcatMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
func compute(commandBuffer: MTLCommandBuffer, param: ConcatParam<P>) throws {
for i in 0..<param.input.count {
for j in 0..<4 {
assert(param.input[i].transpose[j] == j)
}
}
let group = param.input.count / 6
let remain = param.input.count % 6
for i in 0..<group {
try self.encode(commandBuffer, param, 6 * i, 6 * (i + 1))
}
if remain > 0 {
try self.encode(commandBuffer, param, 6 * group, param.input.count)
}
}
func test(cmdBuffer: MTLCommandBuffer, param: ConcatTestParam) {
let group = param.input.count / 6
let remain = param.input.count % 6
for i in 0..<group {
try self.encodeTest(cmdBuffer, param, 6 * i, 6 * (i + 1))
}
if remain > 0 {
try self.encodeTest(cmdBuffer, param, 6 * group, param.input.count)
}
}
required init(device: MTLDevice, param: ConcatParam<P>) {
param.output.initTexture(device: device)
super.init(device: device, inFunctionName: "concat")
}
required init(device: MTLDevice, testParam: ConcatTestParam) {
super.init(device: device, inFunctionName: "concat")
}
}
......@@ -425,3 +425,55 @@ kernel void reshape(texture2d_array<float, access::read> inTexture [[texture(0)]
// half4 r = inTexture.read(uint2(0, 0), gid.x);
// outTexture.write(r, gid.xy, gid.z);
//}
struct ConcatParam {
int32_t odim[4];
int32_t axis;
int32_t offset;
int32_t vdim[6];
};
kernel void concat(texture2d_array<float, access::read> in0 [[texture(0)]],
texture2d_array<float, access::read> in1 [[texture(1)]],
texture2d_array<float, access::read> in2 [[texture(2)]],
texture2d_array<float, access::read> in3 [[texture(3)]],
texture2d_array<float, access::read> in4 [[texture(4)]],
texture2d_array<float, access::read> in5 [[texture(5)]],
texture2d_array<float, access::read> inx [[texture(6)]],
texture2d_array<float, access::write> out [[texture(7)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
ConcatParam cp = pm;
int xyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, abcd[4], oxyzn[4];
float4 r;
for (int i = 0; i < 4; i++) {
xyzn[3] = i;
xyzn2abcd(cp.odim[3], xyzn, abcd);
int k = abcd[cp.axis] - cp.offset;
int j = 0;
if (k < 0) {
r[i] = inx.read(gid.xy, gid.z)[i];
} else {
for (; j < 6; j++) {
if (k < cp.vdim[j]) {
break;
}
k -= cp.vdim[j];
}
int ta = cp.odim[cp.axis];
abcd[cp.axis] = k;
cp.odim[cp.axis] = cp.vdim[j];
abcd2xyzn(cp.odim[3], abcd, oxyzn);
cp.odim[cp.axis] = ta;
switch (j) {
case 0: r[i] = in0.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 1: r[i] = in1.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 2: r[i] = in2.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 3: r[i] = in3.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 4: r[i] = in4.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 5: r[i] = in5.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
}
}
}
out.write(r, gid.xy, gid.z);
}
......@@ -40,6 +40,7 @@ extension InputTexture {
public class Texture<P: PrecisionType>: Tensorial {
var dim: Dim
var tensorDim: Dim
private(set) var originDim: Dim
private var textureDesc: MTLTextureDescriptor!
var metalTexture: MTLTexture!
......@@ -89,7 +90,7 @@ public class Texture<P: PrecisionType>: Tensorial {
} else {
fatalError(" not support ")
}
tensorDim = inDim
dim = fourDim
originDim = fourDim
layout = DataLayout.init([(.N, fourDim[0]), (.C, fourDim[1]), (.H, fourDim[2]), (.W, fourDim[3])])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册