提交 9315073f 编写于 作者: Y Yanzhan Yang 提交者: GitHub

enable all mps conv layer (#1592)

上级 be3cdaaf
......@@ -135,10 +135,18 @@ public class Texture: Tensorial {
}
if computePrecision == .Float16 {
if tensorDim[1] == 1 {
tmpTextureDes.pixelFormat = .r16Float
} else {
tmpTextureDes.pixelFormat = .rgba16Float
}
} else if computePrecision == .Float32 {
if tensorDim[1] == 1 {
tmpTextureDes.pixelFormat = .r32Float
} else {
tmpTextureDes.pixelFormat = .rgba32Float
}
}
tmpTextureDes.usage = [.shaderRead, .shaderWrite]
tmpTextureDes.storageMode = .shared
......
......@@ -111,11 +111,8 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
var shouldUseMPS = false
if #available(iOS 11.0, *), initContext.useMPS {
// 输入输出 tensor channel 必须都大于 4
if param.input.tensorDim[1] > 4 && param.output.tensorDim[1] > 4 {
shouldUseMPS = true
}
}
if shouldUseMPS {
super.init(device: device, inFunctionName: nil, initContext: initContext)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册