提交 10944966 编写于 作者: D dolphin8 提交者: GitHub

Merge pull request #834 from dolphin8/metal

fix
......@@ -27,6 +27,7 @@ class SoftmaxKernel<P: PrecisionType>: Kernel, Computable{
}
required init(device: MTLDevice, param: SoftmaxParam<P>) {
param.output.initTexture(device: device)
super.init(device: device, inFunctionName: "softmax")
}
}
......@@ -49,7 +49,11 @@ class TransposeKernel<P: PrecisionType>: Kernel, Computable, Testable {
for (i, v) in param.input.transpose.enumerated() {
invT[v] = i
}
let realAxis = param.axis.map {invT[$0]}
var axis: [Int] = [0, 1, 2, 3]
for i in 0..<param.axis.count {
axis[4-param.axis.count+i] = 4 - param.axis.count + Int(param.axis[i])
}
let realAxis = axis.map {invT[$0]}
var tmp = TransposeMetalParam.init(realAxis)
tmp.iC = Int32(param.input.dim[param.input.transpose[3]])
tmp.oC = Int32(param.output.dim[3])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册