提交 88d033ed 编写于 作者: D dolphin8

fix

上级 0bb67049
...@@ -27,6 +27,7 @@ class SoftmaxKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -27,6 +27,7 @@ class SoftmaxKernel<P: PrecisionType>: Kernel, Computable{
} }
required init(device: MTLDevice, param: SoftmaxParam<P>) { required init(device: MTLDevice, param: SoftmaxParam<P>) {
param.output.initTexture(device: device)
super.init(device: device, inFunctionName: "softmax") super.init(device: device, inFunctionName: "softmax")
} }
} }
...@@ -49,7 +49,11 @@ class TransposeKernel<P: PrecisionType>: Kernel, Computable, Testable { ...@@ -49,7 +49,11 @@ class TransposeKernel<P: PrecisionType>: Kernel, Computable, Testable {
for (i, v) in param.input.transpose.enumerated() { for (i, v) in param.input.transpose.enumerated() {
invT[v] = i invT[v] = i
} }
let realAxis = param.axis.map {invT[$0]} var axis: [Int] = [0, 1, 2, 3]
for i in 0..<param.axis.count {
axis[4-param.axis.count+i] = 4 - param.axis.count + Int(param.axis[i])
}
let realAxis = axis.map {invT[$0]}
var tmp = TransposeMetalParam.init(realAxis) var tmp = TransposeMetalParam.init(realAxis)
tmp.iC = Int32(param.input.dim[param.input.transpose[3]]) tmp.iC = Int32(param.input.dim[param.input.transpose[3]])
tmp.oC = Int32(param.output.dim[3]) tmp.oC = Int32(param.output.dim[3])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册