diff --git a/metal/paddle-mobile/paddle-mobile/Src/Framework/Texture.swift b/metal/paddle-mobile/paddle-mobile/Src/Framework/Texture.swift index 6b36680288a50a6c847f682f8a609595cef0ffd5..c99b545f396ad2dd721c5e01a217517223fbbe12 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Framework/Texture.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Framework/Texture.swift @@ -135,9 +135,17 @@ public class Texture: Tensorial { } if computePrecision == .Float16 { - tmpTextureDes.pixelFormat = .rgba16Float + if tensorDim[1] == 1 { + tmpTextureDes.pixelFormat = .r16Float + } else { + tmpTextureDes.pixelFormat = .rgba16Float + } } else if computePrecision == .Float32 { - tmpTextureDes.pixelFormat = .rgba32Float + if tensorDim[1] == 1 { + tmpTextureDes.pixelFormat = .r32Float + } else { + tmpTextureDes.pixelFormat = .rgba32Float + } } tmpTextureDes.usage = [.shaderRead, .shaderWrite] diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ConvAddKernel.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ConvAddKernel.swift index 44ecbade3ad18775deede5ef9ed8b4e9f79dc9e9..9acfc2a453aa6cfd3e62bedc70f1afa05759c30c 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ConvAddKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ConvAddKernel.swift @@ -111,10 +111,7 @@ class ConvAddKernel: Kernel, Computable { var shouldUseMPS = false if #available(iOS 11.0, *), initContext.useMPS { - // 输入输出 tensor channel 必须都大于 4 - if param.input.tensorDim[1] > 4 && param.output.tensorDim[1] > 4 { - shouldUseMPS = true - } + shouldUseMPS = true } if shouldUseMPS {