提交 0d0de279 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #893 from codeWorm2015/metal

commit before test
...@@ -31,7 +31,7 @@ class Genet: Net { ...@@ -31,7 +31,7 @@ class Genet: Net {
} }
func resultStr(res: [Float]) -> String { func resultStr(res: [Float]) -> String {
return " 哈哈 还没好 genet !"; return " \(res) "
} }
var preprocessKernel: CusomKernel var preprocessKernel: CusomKernel
......
...@@ -79,7 +79,7 @@ class ViewController: UIViewController { ...@@ -79,7 +79,7 @@ class ViewController: UIViewController {
return return
} }
do { do {
let max = 1 let max = 50
let startDate = Date.init() let startDate = Date.init()
for i in 0..<max { for i in 0..<max {
try net.predict(inTexture: inTexture) { [weak self] (result) in try net.predict(inTexture: inTexture) { [weak self] (result) in
...@@ -87,7 +87,6 @@ class ViewController: UIViewController { ...@@ -87,7 +87,6 @@ class ViewController: UIViewController {
fatalError() fatalError()
} }
print(result.resultArray.strideArray())
if i == max - 1 { if i == max - 1 {
let time = Date.init().timeIntervalSince(startDate) let time = Date.init().timeIntervalSince(startDate)
DispatchQueue.main.async { DispatchQueue.main.async {
......
...@@ -126,11 +126,11 @@ public class Executor<P: PrecisionType> { ...@@ -126,11 +126,11 @@ public class Executor<P: PrecisionType> {
// print(stridableInput) // print(stridableInput)
// let _: Flo? = input.logDesc(header: "input: ", stridable: true) // let _: Flo? = input.logDesc(header: "input: ", stridable: true)
for i in 0..<self.ops.count { // for i in 0..<self.ops.count {
let op = self.ops[i] // let op = self.ops[i]
print(" 第 \(i) 个 op: ") // print(" 第 \(i) 个 op: ")
op.delogOutput() // op.delogOutput()
} // }
// return; // return;
// self.ops[testTo - 2].delogOutput() // self.ops[testTo - 2].delogOutput()
......
...@@ -43,9 +43,7 @@ class SoftmaxKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -43,9 +43,7 @@ class SoftmaxKernel<P: PrecisionType>: Kernel, Computable{
N: Int32(param.input.tensorDim[0]), N: Int32(param.input.tensorDim[0]),
K: Int32(param.input.tensorDim[1]) K: Int32(param.input.tensorDim[1])
) )
print(" soft max param: ")
print(smp)
encoder.setBytes(&smp, length: MemoryLayout<SoftmaxMetalParam>.size, index: 0) encoder.setBytes(&smp, length: MemoryLayout<SoftmaxMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding() encoder.endEncoding()
......
...@@ -50,7 +50,7 @@ public class Texture<P: PrecisionType>: Tensorial { ...@@ -50,7 +50,7 @@ public class Texture<P: PrecisionType>: Tensorial {
guard padToFourDim.cout() == 4 else { guard padToFourDim.cout() == 4 else {
fatalError("- not support -") fatalError("- not support -")
} }
return metalTexture.toTensor(dim: (n: padToFourDim[0], c: padToFourDim[1], h: padToFourDim[2], w: padToFourDim[3])) return metalTexture.toTensor(dim: (n: dim[0], c: dim[3], h: dim[1], w: dim[2]))
} }
func realNHWC() -> [Float32] { func realNHWC() -> [Float32] {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册