提交 c3c6cbd3 编写于 作者: Y Yanzhan Yang 提交者: GitHub

change pixelFormat only when useAggressiveOptimization is set (#1614)

上级 0ee40193
......@@ -71,6 +71,7 @@ extension InputTexture {
public class Texture: Tensorial {
public var dim: Dim
public var tensorDim: Dim
public var useMPS = false
/// tensor dim pad to four
public var padToFourDim: Dim
......@@ -135,14 +136,22 @@ public class Texture: Tensorial {
}
if computePrecision == .Float16 {
if tensorDim[1] == 1 {
tmpTextureDes.pixelFormat = .r16Float
if useMPS {
if tensorDim[1] == 1 {
tmpTextureDes.pixelFormat = .r16Float
} else {
tmpTextureDes.pixelFormat = .rgba16Float
}
} else {
tmpTextureDes.pixelFormat = .rgba16Float
}
} else if computePrecision == .Float32 {
if tensorDim[1] == 1 {
tmpTextureDes.pixelFormat = .r32Float
if useMPS {
if tensorDim[1] == 1 {
tmpTextureDes.pixelFormat = .r32Float
} else {
tmpTextureDes.pixelFormat = .rgba32Float
}
} else {
tmpTextureDes.pixelFormat = .rgba32Float
}
......
......@@ -112,8 +112,14 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
var shouldUseMPS = false
let functionName = type(of: self).kernelFunctionName(param: param, useAggressiveOptimization: initContext.useAggresiveOptimization)
if #available(iOS 11.0, *), (initContext.useMPS || initContext.useAggresiveOptimization) {
if (param.input.tensorDim[1] == 1 || param.input.tensorDim[1] > 4) && (param.output.tensorDim[1] == 1 || param.output.tensorDim[1] > 4) {
shouldUseMPS = true
if initContext.useAggresiveOptimization {
if (param.input.tensorDim[1] == 1 || param.input.tensorDim[1] > 4) && (param.output.tensorDim[1] == 1 || param.output.tensorDim[1] > 4) {
shouldUseMPS = true
}
} else {
if param.input.tensorDim[1] > 4 && param.output.tensorDim[1] > 4 {
shouldUseMPS = true
}
}
}
if type(of: self).isWinoGrad(functionName: functionName) {
......@@ -166,6 +172,8 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
let isDepthWise = param.filter.tensorDim[1] == 1 && param.filter.tensorDim[0] == param.input.tensorDim[1]
if #available(iOS 11.0, *) {
param.input.useMPS = true
param.output.useMPS = true
let desc: MPSCNNConvolutionDescriptor = isDepthWise ?
MPSCNNDepthWiseConvolutionDescriptor(kernelWidth: param.filter.tensorDim[3],
kernelHeight: param.filter.tensorDim[2],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册