Commit 5bbddc28 authored by: Ruilong Liu, committed by: GitHub

Merge pull request #839 from codeWorm2015/metal

update
@@ -14,6 +14,8 @@
import Foundation
let testTo = 41
public class ResultHolder<P: PrecisionType> {
  public let dim: [Int]
  public let resultArr: [P]
@@ -60,7 +62,7 @@ public class Executor<P: PrecisionType> {
    queue = inQueue
    for block in inProgram.programDesc.blocks {
      //block.ops.count
-     for i in 0..<91 {
+     for i in 0..<(testTo + 1) {
        let op = block.ops[i]
        do {
          let op = try OpCreator<P>.shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope)
@@ -128,7 +130,7 @@ public class Executor<P: PrecisionType> {
    }
    // return
-   self.ops[90].delogOutput()
+   self.ops[testTo].delogOutput()
    // self.ops[91].delogOutput()
    // self.ops[92].delogOutput()
    // self.ops[93].delogOutput()
......
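The change above replaces the hard-coded op indices 90/91 with a single `testTo` constant: the executor builds only the first `testTo + 1` ops and dumps the output of op `testTo`, so the layer under inspection can be moved by editing one value. A minimal Swift sketch of that bisection pattern; the `execute`/`dumpOutput` closures are placeholders, not the project's real op API:

let testTo = 41

func runPrefixForDebug<Op>(ops: [Op], execute: (Op) -> Void, dumpOutput: (Op) -> Void) {
    // Run only the prefix of the graph currently under inspection.
    for op in ops.prefix(testTo + 1) {
        execute(op)
    }
    // Dump the last executed op's output; move `testTo` to bisect a bad layer.
    if ops.indices.contains(testTo) {
        dumpOutput(ops[testTo])
    }
}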
@@ -63,7 +63,10 @@ class ConcatOp<P: PrecisionType>: Operator<ConcatKernel<P>, ConcatParam<P>>, Run
      return o
    }
-   print(outputArray.strideArray())
+   // print(outputArray.strideArray())
    writeToLibrary(fileName: "concat_out", array: outputArray)
    let device: MTLDevice = MTLCreateSystemDefaultDevice()!
    // let tensorArray: [P] = device.texture2tensor(texture: para.output.metalTexture, dim: [1917, 4])
......
@@ -47,25 +47,62 @@ kernel void reshape(texture2d_array<float, access::read> inTexture [[texture(0)]],
    invtrans(lrp.otrans, oabcd, tabcd);
    int index = abcd2index(lrp.odim, tabcd);
    if (index < count) {
-     // int c = index % 4;
-     //
-     // int temp0 = index % (inTexture.get_array_size() * 4);
-     // int slice = temp0 / 4;
-     //
-     // int temp1 = index % (inTexture.get_array_size() * 4 * lrp.idim[2]);
-     // int w = temp1 / (inTexture.get_array_size() * 4);
-     //
-     // int h = index / (inTexture.get_array_size() * 4 * lrp.idim[2]);
-     index2abcd(lrp.idim, index, tabcd);
-     abcd2xyzn(iC, tabcd, ixyzn);
-     r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]];
+     int c = index % 4;
+
+     int temp0 = index % (inTexture.get_array_size() * 4);
+     int slice = temp0 / 4;
+
+     int temp1 = index % (inTexture.get_array_size() * 4 * lrp.idim[2]);
+     int w = temp1 / (inTexture.get_array_size() * 4);
+
+     int h = index / (inTexture.get_array_size() * 4 * lrp.idim[2]);
+     // index2abcd(lrp.idim, index, tabcd);
+     // abcd2xyzn(iC, tabcd, ixyzn);
+     r[n] = inTexture.read(uint2(w, h), slice)[c];
    } else {
      r[n] = 0;
    }
  }
  outTexture.write(r, gid.xy, gid.z);
}
/*
kernel void reshape(texture2d_array<float, access::read> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant ReshapeParam &rp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
int oxyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, oabcd[4], ixyzn[4], iabcd[4];
ReshapeParam lrp = rp;
int oC = lrp.odim[lrp.otrans[3]];
int iC = lrp.idim[lrp.itrans[3]];
int count = lrp.odim[0] * lrp.odim[1] * lrp.odim[2] * lrp.odim[3];
float4 r;
for (int n = 0; n < 4; n++) {
oxyzn[3] = n;
xyzn2abcd(oC, oxyzn, oabcd);
int tabcd[4];
invtrans(lrp.otrans, oabcd, tabcd);
int index = abcd2index(lrp.odim, tabcd);
if (index < count) {
index2abcd(lrp.idim, index, tabcd);
trans(lrp.itrans, tabcd, iabcd);
abcd2xyzn(iC, tabcd, ixyzn);
r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]];
} else {
r[n] = 0;
}
}
outTexture.write(r, gid.xy, gid.z);
}
*/
//
//kernel void reshape_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
//                         texture2d_array<half, access::write> outTexture [[texture(1)]],
......
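The rewritten reshape kernel above now derives the input texture coordinates directly from the flat element index instead of going through index2abcd/abcd2xyzn. A small CPU-side Swift sketch of that arithmetic, handy for sanity-checking the mapping; `arraySize` stands for inTexture.get_array_size() and `widthDim` for lrp.idim[2] (which dimension that field holds is an assumption based on the kernel):

// Decompose a flat element index into (channel-in-texel, slice, w, h) for a
// texture2d_array whose texels pack 4 channels, mirroring the kernel's math.
func texelCoords(forIndex index: Int, arraySize: Int, widthDim: Int) -> (c: Int, slice: Int, w: Int, h: Int) {
    let c = index % 4
    let temp0 = index % (arraySize * 4)
    let slice = temp0 / 4
    let temp1 = index % (arraySize * 4 * widthDim)
    let w = temp1 / (arraySize * 4)
    let h = index / (arraySize * 4 * widthDim)
    return (c, slice, w, h)
}

// With 2 slices and widthDim 3, index 13 lands at the texel read by
// r[n] = inTexture.read(uint2(w, h), slice)[c]:
print(texelCoords(forIndex: 13, arraySize: 2, widthDim: 3))  // (c: 1, slice: 1, w: 1, h: 0)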
@@ -55,5 +55,8 @@ class ReshapeOp<P: PrecisionType>: Operator<ReshapeKernel<P>, ReshapeParam<P>>,
    // let _: P? = para.input.metalTexture.logDesc(header: "reshape input: ", stridable: false)
    let _: P? = para.output.metalTexture.logDesc(header: "reshape output: ", stridable: true)
    let device = MTLCreateSystemDefaultDevice()!
    let array: [Float32] = device.texture2tensor(texture: para.output.metalTexture, dim: [1, 1, 600, 7])
    print(array.strideArray())
  }
}
@@ -60,6 +60,7 @@ class TransposeOp<P: PrecisionType>: Operator<TransposeKernel<P>, TransposeParam
    }
    print(outputArray.strideArray())
    // writeToLibrary(fileName: "transpose_ouput", array: outputArray)
  }
......
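The delogOutput bodies above all dump tensors through strideArray(), which keeps large debug arrays readable by printing only a sampled view. A hedged sketch of such a helper, named sampled(every:) here to avoid implying it is the project's actual implementation:

extension Array {
    // Return every `stride`-th element so very large tensors stay readable in the log.
    func sampled(every stride: Int = 100) -> [Element] {
        guard stride > 0 else { return self }
        return Swift.stride(from: 0, to: count, by: stride).map { self[$0] }
    }
}

// Usage against a dumped tensor, e.g. a reshape output of shape [1, 1, 600, 7]:
let dumped = (0..<4200).map { Float32($0) }
print(dumped.sampled(every: 500))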