未验证 提交 ac8b7b51 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #838 from codeWorm2015/metal

 commit
...@@ -98,14 +98,14 @@ extension Array where Element: Comparable{ ...@@ -98,14 +98,14 @@ extension Array where Element: Comparable{
} }
extension Array { extension Array {
func strideArray(inCount: Int = 20) -> Array<Element> { func strideArray(inCount: Int = 20) -> [(Int, Element)] {
if count < inCount { if count < inCount {
return self return (0..<count).map{ ($0, self[$0]) }
} else { } else {
let stride = count / inCount let stride = count / inCount
var newArray: [Element] = [] var newArray: [(Int, Element)] = []
for i in 0..<inCount { for i in 0..<inCount {
newArray.append(self[i * stride]) newArray.append((i * stride, self[i * stride]))
} }
return newArray return newArray
} }
......
...@@ -60,7 +60,7 @@ public class Executor<P: PrecisionType> { ...@@ -60,7 +60,7 @@ public class Executor<P: PrecisionType> {
queue = inQueue queue = inQueue
for block in inProgram.programDesc.blocks { for block in inProgram.programDesc.blocks {
//block.ops.count //block.ops.count
for i in 0..<39 { for i in 0..<91 {
let op = block.ops[i] let op = block.ops[i]
do { do {
let op = try OpCreator<P>.shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope) let op = try OpCreator<P>.shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope)
...@@ -128,7 +128,7 @@ public class Executor<P: PrecisionType> { ...@@ -128,7 +128,7 @@ public class Executor<P: PrecisionType> {
} }
// return // return
self.ops[38].delogOutput() self.ops[90].delogOutput()
// self.ops[91].delogOutput() // self.ops[91].delogOutput()
// self.ops[92].delogOutput() // self.ops[92].delogOutput()
// self.ops[93].delogOutput() // self.ops[93].delogOutput()
......
...@@ -62,6 +62,7 @@ class ConcatOp<P: PrecisionType>: Operator<ConcatKernel<P>, ConcatParam<P>>, Run ...@@ -62,6 +62,7 @@ class ConcatOp<P: PrecisionType>: Operator<ConcatKernel<P>, ConcatParam<P>>, Run
let outputArray = para.output.metalTexture.floatArray { (o: Float32) -> Float32 in let outputArray = para.output.metalTexture.floatArray { (o: Float32) -> Float32 in
return o return o
} }
print(outputArray.strideArray()) print(outputArray.strideArray())
let device: MTLDevice = MTLCreateSystemDefaultDevice()! let device: MTLDevice = MTLCreateSystemDefaultDevice()!
......
...@@ -25,7 +25,7 @@ class ConvAddKernel<P: PrecisionType>: Kernel, Computable { ...@@ -25,7 +25,7 @@ class ConvAddKernel<P: PrecisionType>: Kernel, Computable {
super.init(device: device, inFunctionName: "conv_add_3x3") super.init(device: device, inFunctionName: "conv_add_3x3")
} }
param.output.initTexture(device: device, inTranspose: [0, 3, 2, 1]) param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1])
let offsetX = param.filter.width/2 - Int(param.paddings[0]) let offsetX = param.filter.width/2 - Int(param.paddings[0])
let offsetY = param.filter.height/2 - Int(param.paddings[1]) let offsetY = param.filter.height/2 - Int(param.paddings[1])
......
...@@ -47,20 +47,19 @@ kernel void reshape(texture2d_array<float, access::read> inTexture [[texture(0)] ...@@ -47,20 +47,19 @@ kernel void reshape(texture2d_array<float, access::read> inTexture [[texture(0)]
invtrans(lrp.otrans, oabcd, tabcd); invtrans(lrp.otrans, oabcd, tabcd);
int index = abcd2index(lrp.odim, tabcd); int index = abcd2index(lrp.odim, tabcd);
if (index < count) { if (index < count) {
int c = index % 4; // int c = index % 4;
//
int temp0 = index % (inTexture.get_array_size() * 4); // int temp0 = index % (inTexture.get_array_size() * 4);
int slice = temp0 / 4; // int slice = temp0 / 4;
//
int temp1 = index % (inTexture.get_array_size() * 4 * lrp.idim[2]); // int temp1 = index % (inTexture.get_array_size() * 4 * lrp.idim[2]);
int w = temp1 / (inTexture.get_array_size() * 4); // int w = temp1 / (inTexture.get_array_size() * 4);
//
int h = index / (inTexture.get_array_size() * 4 * lrp.idim[2]); // int h = index / (inTexture.get_array_size() * 4 * lrp.idim[2]);
// index2abcd(lrp.idim, index, tabcd); index2abcd(lrp.idim, index, tabcd);
// abcd2xyzn(iC, tabcd, ixyzn); abcd2xyzn(iC, tabcd, ixyzn);
r[n] = inTexture.read(uint2(w, h), slice)[c]; r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]];
} else { } else {
r[n] = 0; r[n] = 0;
} }
......
...@@ -53,6 +53,7 @@ class ReshapeOp<P: PrecisionType>: Operator<ReshapeKernel<P>, ReshapeParam<P>>, ...@@ -53,6 +53,7 @@ class ReshapeOp<P: PrecisionType>: Operator<ReshapeKernel<P>, ReshapeParam<P>>,
func delogOutput() { func delogOutput() {
print("reshape delog") print("reshape delog")
// let _: P? = para.input.metalTexture.logDesc(header: "reshape input: ", stridable: false) // let _: P? = para.input.metalTexture.logDesc(header: "reshape input: ", stridable: false)
let _: P? = para.output.metalTexture.logDesc(header: "reshape output: ", stridable: false) let _: P? = para.output.metalTexture.logDesc(header: "reshape output: ", stridable: true)
} }
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册