From 9315073fbf15e5c5467c81778ac7c01115678429 Mon Sep 17 00:00:00 2001 From: Yanzhan Yang Date: Tue, 30 Apr 2019 18:20:23 +0800 Subject: [PATCH] enable all mps conv layer (#1592) --- .../paddle-mobile/Src/Framework/Texture.swift | 12 ++++++++++-- .../Src/Operators/Kernels/ConvAddKernel.swift | 5 +---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/metal/paddle-mobile/paddle-mobile/Src/Framework/Texture.swift b/metal/paddle-mobile/paddle-mobile/Src/Framework/Texture.swift index 6b36680288..c99b545f39 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Framework/Texture.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Framework/Texture.swift @@ -135,9 +135,17 @@ public class Texture: Tensorial { } if computePrecision == .Float16 { - tmpTextureDes.pixelFormat = .rgba16Float + if tensorDim[1] == 1 { + tmpTextureDes.pixelFormat = .r16Float + } else { + tmpTextureDes.pixelFormat = .rgba16Float + } } else if computePrecision == .Float32 { - tmpTextureDes.pixelFormat = .rgba32Float + if tensorDim[1] == 1 { + tmpTextureDes.pixelFormat = .r32Float + } else { + tmpTextureDes.pixelFormat = .rgba32Float + } } tmpTextureDes.usage = [.shaderRead, .shaderWrite] diff --git a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ConvAddKernel.swift b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ConvAddKernel.swift index 44ecbade3a..9acfc2a453 100644 --- a/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ConvAddKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Src/Operators/Kernels/ConvAddKernel.swift @@ -111,10 +111,7 @@ class ConvAddKernel: Kernel, Computable { var shouldUseMPS = false if #available(iOS 11.0, *), initContext.useMPS { - // 输入输出 tensor channel 必须都大于 4 - if param.input.tensorDim[1] > 4 && param.output.tensorDim[1] > 4 { - shouldUseMPS = true - } + shouldUseMPS = true } if shouldUseMPS { -- GitLab