From c3c7b07d908c1ca673fec1d0c5373258a41c2a66 Mon Sep 17 00:00:00 2001 From: liuruilong Date: Thu, 26 Jul 2018 17:28:41 +0800 Subject: [PATCH] add unit test --- .../project.pbxproj | 8 + .../paddle-mobile-demo/MetalHelper.swift | 23 +++ .../paddle-mobile-demo/PreProcessKernel.metal | 12 +- .../paddle-mobile-demo/ViewController.swift | 43 ++-- .../paddle-mobile.xcodeproj/project.pbxproj | 8 + .../paddle-mobile/Common/Extensions.swift | 3 +- .../paddle-mobile/Common/MetalExtension.swift | 186 ++++++++++++++---- .../Common/PaddleMobileUnitTest.swift | 149 ++++++++++++++ .../paddle-mobile/Common/Tools.swift | 21 ++ .../paddle-mobile/Executor.swift | 38 ++-- .../Operators/Base/Operator.swift | 3 + .../Operators/ConvAddBatchNormReluOp.swift | 11 +- .../paddle-mobile/Operators/FeedOp.swift | 2 +- .../paddle-mobile/Operators/FetchOp.swift | 4 +- .../Kernels/ConvAddBatchNormReluKernel.swift | 54 ++++- .../Operators/Kernels/ConvKernel.metal | 17 +- .../Operators/Kernels/ConvKernel.swift | 2 +- .../Operators/Kernels/Kernel.swift | 10 + .../Operators/Kernels/Kernels.metal | 13 +- .../paddle-mobile/framework/Dim.swift | 3 +- .../paddle-mobile/framework/Texture.swift | 1 + 21 files changed, 514 insertions(+), 97 deletions(-) create mode 100644 metal/paddle-mobile-demo/paddle-mobile-demo/MetalHelper.swift create mode 100644 metal/paddle-mobile/paddle-mobile/Common/PaddleMobileUnitTest.swift create mode 100644 metal/paddle-mobile/paddle-mobile/Common/Tools.swift diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/project.pbxproj b/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/project.pbxproj index e5a0e3bc21..e224d50427 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/project.pbxproj @@ -215,10 +215,12 @@ FC0E2DB420EDC03C009C1FAC /* conv2d_27.w_0 in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2CEA20EDC03B009C1FAC /* conv2d_27.w_0 */; }; FC0E2DB520EDC03C009C1FAC /* conv2d_33.w_0 in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2CEB20EDC03B009C1FAC /* conv2d_33.w_0 */; }; FC0E2DB620EDC03C009C1FAC /* depthwise_conv2d_7.w_0 in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2CEC20EDC03B009C1FAC /* depthwise_conv2d_7.w_0 */; }; + FC3602C82108580600FACB58 /* MetalHelper.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3602C72108580600FACB58 /* MetalHelper.swift */; }; FCD04E6320F3146B0007374F /* params in Resources */ = {isa = PBXBuildFile; fileRef = FCD04E6120F3146A0007374F /* params */; }; FCD04E6420F3146B0007374F /* model in Resources */ = {isa = PBXBuildFile; fileRef = FCD04E6220F3146A0007374F /* model */; }; FCEBEC2C20E1391F00C0B14D /* paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; }; FCEBEC2D20E1391F00C0B14D /* paddle_mobile.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; }; + FCEEE7D4210627A000444BEC /* banana.jpeg in Resources */ = {isa = PBXBuildFile; fileRef = FCEEE7D3210627A000444BEC /* banana.jpeg */; }; /* End PBXBuildFile section */ /* Begin PBXCopyFilesBuildPhase section */ @@ -448,9 +450,11 @@ FC0E2CEA20EDC03B009C1FAC /* conv2d_27.w_0 */ = {isa = PBXFileReference; lastKnownFileType = file; path = conv2d_27.w_0; sourceTree = ""; }; FC0E2CEB20EDC03B009C1FAC /* conv2d_33.w_0 */ = {isa = PBXFileReference; lastKnownFileType = file; path = conv2d_33.w_0; sourceTree = ""; }; FC0E2CEC20EDC03B009C1FAC /* depthwise_conv2d_7.w_0 */ = {isa = PBXFileReference; lastKnownFileType = file; path = depthwise_conv2d_7.w_0; sourceTree = ""; }; + FC3602C72108580600FACB58 /* MetalHelper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalHelper.swift; sourceTree = ""; }; FCD04E6120F3146A0007374F /* params */ = {isa = PBXFileReference; lastKnownFileType = file; path = params; sourceTree = ""; }; FCD04E6220F3146A0007374F /* model */ = {isa = PBXFileReference; lastKnownFileType = file; path = model; sourceTree = ""; }; FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; }; + FCEEE7D3210627A000444BEC /* banana.jpeg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = banana.jpeg; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -514,6 +518,7 @@ FC039B8820E11C560081E9F8 /* Assets.xcassets */, FC039B8A20E11C560081E9F8 /* LaunchScreen.storyboard */, FC039B8D20E11C560081E9F8 /* Info.plist */, + FC3602C72108580600FACB58 /* MetalHelper.swift */, ); path = "paddle-mobile-demo"; sourceTree = ""; @@ -521,6 +526,7 @@ FC0E2C1D20EDC030009C1FAC /* images */ = { isa = PBXGroup; children = ( + FCEEE7D3210627A000444BEC /* banana.jpeg */, FC0E2C1E20EDC030009C1FAC /* apple.jpg */, ); name = images; @@ -900,6 +906,7 @@ FC0E2D1120EDC03B009C1FAC /* conv2d_10.w_0 in Resources */, FC0E2D7120EDC03C009C1FAC /* conv2d_2.w_0 in Resources */, FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */, + FCEEE7D4210627A000444BEC /* banana.jpeg in Resources */, FC0E2D5020EDC03C009C1FAC /* batch_norm_31.w_1 in Resources */, FC0E2D2B20EDC03B009C1FAC /* batch_norm_34.w_1 in Resources */, FC0E2D8F20EDC03C009C1FAC /* conv2d_20.w_0 in Resources */, @@ -1072,6 +1079,7 @@ FC039B8420E11C550081E9F8 /* ViewController.swift in Sources */, FC013928210204A3008100E3 /* PreProcessKernel.metal in Sources */, FC039B8220E11C550081E9F8 /* AppDelegate.swift in Sources */, + FC3602C82108580600FACB58 /* MetalHelper.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo/MetalHelper.swift b/metal/paddle-mobile-demo/paddle-mobile-demo/MetalHelper.swift new file mode 100644 index 0000000000..9a242f5617 --- /dev/null +++ b/metal/paddle-mobile-demo/paddle-mobile-demo/MetalHelper.swift @@ -0,0 +1,23 @@ +// +// MetalHelper.swift +// paddle-mobile-demo +// +// Created by liuRuiLong on 2018/7/25. +// Copyright © 2018年 orange. All rights reserved. +// + +import Metal +import paddle_mobile +import Foundation + + +class MetalHelper { + let device: MTLDevice + let queue: MTLCommandQueue + static let shared: MetalHelper = MetalHelper.init() + private init(){ + device = MTLCreateSystemDefaultDevice()! + queue = device.makeCommandQueue()! + } +} + diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo/PreProcessKernel.metal b/metal/paddle-mobile-demo/paddle-mobile-demo/PreProcessKernel.metal index f32ab30d6b..a9ab6430d6 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo/PreProcessKernel.metal +++ b/metal/paddle-mobile-demo/paddle-mobile-demo/PreProcessKernel.metal @@ -19,10 +19,12 @@ kernel void preprocess( gid.y >= outTexture.get_height()) { return; } - // Subtract mean values, scale by 0.017, convert to BGR. - - const auto means = float4(103.94f, 116.78f, 123.68f, 0.0f); - const float4 inColor = (float4(inTexture.read(gid)) * 255.0f - means) * 0.017f; - outTexture.write(float4(inColor.x, inColor.y, inColor.z, 0.0f), gid); + const auto means = float4(123.68f, 116.78f, 103.94f, 0.0f); + const float4 inColor = (float4(float4(inTexture.read(gid))) * 255.0f - means) * 0.017f; + outTexture.write(float4(inColor.z, inColor.y, inColor.x, 0.0f), gid); } + + + + diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift b/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift index 74ad1c7a51..5e96655e76 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift +++ b/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift @@ -17,6 +17,8 @@ import MetalKit import paddle_mobile import MetalPerformanceShaders +let openTest: Bool = false + class PreProccess: CusomKernel { init(device: MTLDevice) { let s = CusomKernel.Shape.init(inWidth: 224, inHeight: 224, inChannel: 3) @@ -26,7 +28,6 @@ class PreProccess: CusomKernel { class ViewController: UIViewController { - let device: MTLDevice! = MTLCreateSystemDefaultDevice() var textureLoader: MTKTextureLoader! // let queue: MTLCommandQueue func scaleTexture(queue: MTLCommandQueue, input: MTLTexture, complete: @escaping (MTLTexture) -> Void) { @@ -39,9 +40,9 @@ class ViewController: UIViewController { tmpTextureDes.textureType = .type2D tmpTextureDes.storageMode = .shared tmpTextureDes.cpuCacheMode = .defaultCache - let dest = device.makeTexture(descriptor: tmpTextureDes) + let dest = MetalHelper.shared.device.makeTexture(descriptor: tmpTextureDes) - let scale = MPSImageLanczosScale.init(device: device) + let scale = MPSImageLanczosScale.init(device: MetalHelper.shared.device) let buffer = queue.makeCommandBuffer() scale.encode(commandBuffer: buffer!, sourceTexture: input, destinationTexture: dest!) @@ -51,12 +52,27 @@ class ViewController: UIViewController { buffer?.commit() } + func unitTest() { + let unitTest = PaddleMobileUnitTest.init(inDevice: MetalHelper.shared.device, inQueue: MetalHelper.shared.queue) + unitTest.testConvAddBnRelu() + } + override func viewDidLoad() { super.viewDidLoad() - let queue = device.makeCommandQueue() - textureLoader = MTKTextureLoader.init(device: device) - guard let appleImage = UIImage.init(named: "apple.jpg"), let cgImage = appleImage.cgImage else { + if openTest { + print(" - testing - ") + unitTest() + return + } + + + +// return + let queue = MetalHelper.shared.queue + + textureLoader = MTKTextureLoader.init(device: MetalHelper.shared.device) + guard let appleImage = UIImage.init(named: "banana.jpeg"), let cgImage = appleImage.cgImage else { fatalError(" image nil !") } @@ -65,19 +81,18 @@ class ViewController: UIViewController { guard let inTexture = texture else { fatalError(" texture is nil !") } - - - scaleTexture(queue: queue!, input: inTexture) { (inputTexture) in + scaleTexture(queue: queue, input: inTexture) { (inputTexture) in let loader = Loader.init() do { let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null" let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null" - let program = try loader.load(device: self.device, modelPath: modelPath, paraPath: paraPath) - let executor = try Executor.init(inDevice: self.device, inQueue: queue!, inProgram: program) - let preprocessKernel = PreProccess.init(device: self.device) - let output = try executor.predict(input: inputTexture, expect: [1, 224, 224, 3], preProcessKernle: preprocessKernel) - // print(output) + let program = try loader.load(device: MetalHelper.shared.device, modelPath: modelPath, paraPath: paraPath) + let executor = try Executor.init(inDevice: MetalHelper.shared.device, inQueue: queue, inProgram: program) + let preprocessKernel = PreProccess.init(device: MetalHelper.shared.device) + try executor.predict(input: inputTexture, expect: [1, 224, 224, 3], completionHandle: { (result) in + print(result.resultArr.top(r: 5)) + }, preProcessKernle: preprocessKernel) } catch let error { print(error) } diff --git a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj index 1eed321f54..6bceab4321 100644 --- a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj @@ -36,6 +36,7 @@ FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */; }; FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC1B16B220EC9A4F00678B91 /* Kernels.metal */; }; FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC1B186520ECF1C600678B91 /* ResizeKernel.swift */; }; + FC3602CC2108819F00FACB58 /* PaddleMobileUnitTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3602CB2108819F00FACB58 /* PaddleMobileUnitTest.swift */; }; FC4CB74920F0B954007C0C6D /* ConvKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC4CB74820F0B954007C0C6D /* ConvKernel.metal */; }; FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */; }; FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */; }; @@ -53,6 +54,7 @@ FCD04E7020F31B720007374F /* ReshapeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E6F20F31B720007374F /* ReshapeKernel.swift */; }; FCD04E7220F343420007374F /* ConvAddOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E7120F343420007374F /* ConvAddOp.swift */; }; FCD04E7420F3437E0007374F /* ConvAddKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E7320F3437E0007374F /* ConvAddKernel.swift */; }; + FCDC0FEB21099A1D00DC9EFB /* Tools.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCDC0FEA21099A1D00DC9EFB /* Tools.swift */; }; FCEBC0F420F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCEBC0F320F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift */; }; FCEBC0F620F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCEBC0F520F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift */; }; FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCF2D73720E64E70007AC5F5 /* Kernel.swift */; }; @@ -92,6 +94,7 @@ FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ElementwiseAddKernel.swift; sourceTree = ""; }; FC1B16B220EC9A4F00678B91 /* Kernels.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = Kernels.metal; sourceTree = ""; }; FC1B186520ECF1C600678B91 /* ResizeKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ResizeKernel.swift; sourceTree = ""; }; + FC3602CB2108819F00FACB58 /* PaddleMobileUnitTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PaddleMobileUnitTest.swift; sourceTree = ""; }; FC4CB74820F0B954007C0C6D /* ConvKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvKernel.metal; sourceTree = ""; }; FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ProgramOptimize.swift; sourceTree = ""; }; FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture2DTo2DArrayKernel.swift; sourceTree = ""; }; @@ -109,6 +112,7 @@ FCD04E6F20F31B720007374F /* ReshapeKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ReshapeKernel.swift; sourceTree = ""; }; FCD04E7120F343420007374F /* ConvAddOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddOp.swift; sourceTree = ""; }; FCD04E7320F3437E0007374F /* ConvAddKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddKernel.swift; sourceTree = ""; }; + FCDC0FEA21099A1D00DC9EFB /* Tools.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Tools.swift; sourceTree = ""; }; FCEBC0F320F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; name = ConvAddBatchNormReluOp.swift; path = "paddle-mobile/Operators/ConvAddBatchNormReluOp.swift"; sourceTree = SOURCE_ROOT; }; FCEBC0F520F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddBatchNormReluKernel.swift; sourceTree = ""; }; FCF2D73720E64E70007AC5F5 /* Kernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; name = Kernel.swift; path = "paddle-mobile/Operators/Kernels/Kernel.swift"; sourceTree = SOURCE_ROOT; }; @@ -182,7 +186,9 @@ FC039B9420E11C9A0081E9F8 /* Extensions.swift */, FC039B9520E11C9A0081E9F8 /* Errors.swift */, FC039B9620E11C9A0081E9F8 /* Types.swift */, + FC3602CB2108819F00FACB58 /* PaddleMobileUnitTest.swift */, FC60DB8820E9AAA500FF203F /* MetalExtension.swift */, + FCDC0FEA21099A1D00DC9EFB /* Tools.swift */, ); path = Common; sourceTree = ""; @@ -374,6 +380,7 @@ FCD04E7220F343420007374F /* ConvAddOp.swift in Sources */, FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */, FC9D037920E229E4000F735A /* OpParam.swift in Sources */, + FC3602CC2108819F00FACB58 /* PaddleMobileUnitTest.swift in Sources */, FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */, FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */, FCEBC0F420F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift in Sources */, @@ -396,6 +403,7 @@ FCD04E6620F314C50007374F /* PoolOp.swift in Sources */, FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */, FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */, + FCDC0FEB21099A1D00DC9EFB /* Tools.swift in Sources */, FC0E2DBA20EE3B8D009C1FAC /* ReluKernel.swift in Sources */, FC82735920E3C04200BE430A /* OpCreator.swift in Sources */, FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */, diff --git a/metal/paddle-mobile/paddle-mobile/Common/Extensions.swift b/metal/paddle-mobile/paddle-mobile/Common/Extensions.swift index 946af08b93..62954ede17 100644 --- a/metal/paddle-mobile/paddle-mobile/Common/Extensions.swift +++ b/metal/paddle-mobile/paddle-mobile/Common/Extensions.swift @@ -24,6 +24,7 @@ public func ?!(option: T?, excuteOrError: @autoclosure () -> String) -> T{ if let inOpt = option { return inOpt }else{ + print(excuteOrError()) fatalError(excuteOrError()) } } @@ -90,7 +91,7 @@ extension Array where Element: Comparable{ /// /// - Parameter r: 前 r 个元素 /// - Returns: [(原有位置, 排好位置的元素)] - func top(r: Int) -> [(Int, Element)] { + public func top(r: Int) -> [(Int, Element)] { precondition(r <= self.count) return Array<(Int, Element)>(zip(0.. $1.1 }.prefix(through: r - 1)) } diff --git a/metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift b/metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift index 5ae0d66470..af4c01d5fe 100644 --- a/metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift +++ b/metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift @@ -61,6 +61,47 @@ extension MTLDevice { } } + + func makeBuffer

(value: [P]) -> MTLBuffer { + let buffer = makeBuffer(length: value.count * MemoryLayout

.size, options: MTLResourceOptions.storageModeShared) + let contents = buffer?.contents().bindMemory(to: P.self, capacity: value.count * MemoryLayout

.size) + for i in 0..(value: [P], textureWidth: Int, textureHeight: Int, arrayLength: Int) -> MTLTexture{ + + let textureDesc = MTLTextureDescriptor.init() + textureDesc.width = textureWidth + textureDesc.height = textureHeight + textureDesc.depth = 1 + textureDesc.usage = [.shaderRead, .shaderWrite] + textureDesc.pixelFormat = .rgba32Float + textureDesc.textureType = .type2DArray + textureDesc.storageMode = .shared + textureDesc.cpuCacheMode = .defaultCache + textureDesc.arrayLength = arrayLength + let texture = makeTexture(descriptor: textureDesc)! + + if arrayLength == 1 && value.count >= 4{ + let pointer: UnsafeMutablePointer

= UnsafeMutablePointer

.allocate(capacity: value.count * MemoryLayout

.size) + for i in 0...size + let region = MTLRegion.init(origin: MTLOrigin.init(x: 0, y: 0, z: 0), size: MTLSize.init(width: texture.width, height: texture.height, depth: texture.depth)) + texture.replace(region: region, mipmapLevel: 0, withBytes: pointer, bytesPerRow: bytesPerRow) + } else { + + + + } + + return texture + } } extension MTLComputeCommandEncoder { @@ -79,63 +120,117 @@ extension MTLComputeCommandEncoder { let groupDepth = slices let groups = MTLSize.init(width: groupWidth, height: groupHeight, depth: groupDepth) -// print("groups: \(groups) ") + print("groups: \(groups) ") + print("threads per group: \(threadsPerGroup)") setComputePipelineState(computePipline) + dispatchThreadgroups(groups, threadsPerThreadgroup: threadsPerGroup) } } public extension MTLTexture { - func logDesc(header: String = "", stridable: Bool = true) -> T? { - print(header) - print("texture: \(self)") + + func stridableFloatArray

(stridable: Bool = true) -> [(index: Int, value: P)] { + var arr: [P] = floatArray { (p: P) -> P in + return p; + } + var result: [(index: Int, value: P)] = [] + if arr.count > 100 && stridable { + for j in stride(from: 0, to: arr.count , by: arr.count / 100){ + result.append((j, arr[j])) + } + } else { + for j in 0..(res: (P) -> T) -> [T] { + var fArr: [T] = [] if textureType == .type2DArray { for i in 0...size, alignment: MemoryLayout.alignment) - let bytesPerRow = width * depth * 4 * MemoryLayout.size - let bytesPerImage = width * height * depth * 4 * MemoryLayout.size + let bytes = UnsafeMutableRawPointer.allocate(byteCount: width * height * 4 * MemoryLayout

.size, alignment: MemoryLayout

.alignment) + let bytesPerRow = width * depth * 4 * MemoryLayout

.size + let bytesPerImage = width * height * depth * 4 * MemoryLayout

.size let region = MTLRegion.init(origin: MTLOrigin.init(x: 0, y: 0, z: 0), size: MTLSize.init(width: width, height: height, depth: depth)) getBytes(bytes, bytesPerRow: bytesPerRow, bytesPerImage: bytesPerImage, from: region, mipmapLevel: 0, slice: i) - let p = bytes.assumingMemoryBound(to: T.self) - str += "2d array count : \(width * height * depth * 4) \n" - if stridable && width * height * depth * 4 > 100 { - for j in stride(from: 0, to: width * height * depth * 4 , by: width * height * depth * 4 / 100){ - str += " index \(j): \(p[j])" - } - } else { - for j in 0...size, alignment: MemoryLayout.alignment) - let bytesPerRow = width * depth * 4 * MemoryLayout.size + let bytes = UnsafeMutableRawPointer.allocate(byteCount: width * height * 4 * MemoryLayout

.size, alignment: MemoryLayout

.alignment) + let bytesPerRow = width * depth * 4 * MemoryLayout

.size let region = MTLRegion.init(origin: MTLOrigin.init(x: 0, y: 0, z: 0), size: MTLSize.init(width: width, height: height, depth: depth)) getBytes(bytes, bytesPerRow: bytesPerRow, from: region, mipmapLevel: 0) - let p = bytes.assumingMemoryBound(to: T.self) - str += "2d count : \(width * width * 4) \n" - - if stridable { - for j in stride(from: 0, to: width * height * 4, by: width * height * 4 / 100){ - str += " \(p[j])" - } - } else { - for j in 0..(header: String = "", stridable: Bool = true) -> T? { + print(header) + print("texture: \(self)") + let res: [(index: Int, value: T)] = stridableFloatArray(stridable: stridable) + print(res) + +// if textureType == .type2DArray { +// for i in 0...size, alignment: MemoryLayout.alignment) +// let bytesPerRow = width * depth * 4 * MemoryLayout.size +// let bytesPerImage = width * height * depth * 4 * MemoryLayout.size +// let region = MTLRegion.init(origin: MTLOrigin.init(x: 0, y: 0, z: 0), size: MTLSize.init(width: width, height: height, depth: depth)) +// getBytes(bytes, bytesPerRow: bytesPerRow, bytesPerImage: bytesPerImage, from: region, mipmapLevel: 0, slice: i) +// let p = bytes.assumingMemoryBound(to: T.self) +// str += "2d array count : \(width * height * depth * 4) \n" +// if stridable && width * height * depth * 4 > 100 { +// for j in stride(from: 0, to: width * height * depth * 4 , by: width * height * depth * 4 / 100){ +// str += " index \(j): \(p[j])" +// } +// } else { +// for j in 0...size, alignment: MemoryLayout.alignment) +// let bytesPerRow = width * depth * 4 * MemoryLayout.size +// let region = MTLRegion.init(origin: MTLOrigin.init(x: 0, y: 0, z: 0), size: MTLSize.init(width: width, height: height, depth: depth)) +// getBytes(bytes, bytesPerRow: bytesPerRow, from: region, mipmapLevel: 0) +// let p = bytes.assumingMemoryBound(to: T.self) +// str += "2d count : \(width * width * 4) \n" +// +// if stridable { +// for j in stride(from: 0, to: width * height * 4, by: width * height * 4 / 100){ +// str += "index \(j): \(p[j]) " +// } +// } else { +// for j in 0.. MTLTexture { + let textureDesc = MTLTextureDescriptor.init() + textureDesc.width = textureWidth + textureDesc.height = textureHeight + textureDesc.depth = 1 + textureDesc.usage = [.shaderRead, .shaderWrite] + textureDesc.pixelFormat = .rgba32Float + textureDesc.textureType = .type2DArray + textureDesc.storageMode = .shared + textureDesc.cpuCacheMode = .defaultCache + textureDesc.arrayLength = arrayLength + let texture = makeTexture(descriptor: textureDesc, offset: 0, bytesPerRow: textureWidth * 4 * 4)! + return texture + } + + } diff --git a/metal/paddle-mobile/paddle-mobile/Common/PaddleMobileUnitTest.swift b/metal/paddle-mobile/paddle-mobile/Common/PaddleMobileUnitTest.swift new file mode 100644 index 0000000000..a2927c4693 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Common/PaddleMobileUnitTest.swift @@ -0,0 +1,149 @@ +// +// TestConvAddBatchNormRelu.swift +// paddle-mobile-demo +// +// Created by liuRuiLong on 2018/7/25. +// Copyright © 2018年 orange. All rights reserved. +// + +import Metal +import Foundation + +public class PaddleMobileUnitTest { + let device: MTLDevice + let queue: MTLCommandQueue + public init(inDevice: MTLDevice, inQueue: MTLCommandQueue) { + device = inDevice + queue = inQueue + } + + public func testConvAddBnRelu() { + let buffer = queue.makeCommandBuffer() ?! " buffer is nil " + + let input: [Float32] = [ + 1.0, 2.0, 3.0, 4.0, + 1.0, 2.0, 3.0, 4.0, + 1.0, 2.0, 3.0, 4.0, + + 1.0, 2.0, 3.0, 4.0, + 1.0, 2.0, 3.0, 4.0, + 1.0, 2.0, 3.0, 4.0, + + 1.0, 2.0, 3.0, 4.0, + 1.0, 2.0, 3.0, 4.0, + 1.0, 2.0, 3.0, 4.0, + ] + + let filter: [Float32] = [ + //1.0 + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + + //2.0 + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + + //3.0 + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + + //4.0 + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, + ] + + let biase: [Float32] = [1.0, 1.0, 1.0, 100.0] + let newScalue: [Float32] = [1.0, 1.0, 1.0, 1.0] + let newBiase: [Float32] = [1.0, 1.0, 1.0, 1.0] + + let inputeTexture = device.makeFloatTexture(value: input, textureWidth: 3, textureHeight: 3, arrayLength: 1) + + //filter + let filterBuffer = device.makeBuffer(value: filter) + + // biase + let biaseBuffer = device.makeBuffer(value: biase) + + // new scale + let newScalueBuffer = device.makeBuffer(value: newScalue) + + // new biase + let newBiaseBuffer = device.makeBuffer(value: newBiase) + + //output + let outputTexture = device.makeFloatTexture(value: [Float32](), textureWidth: 2, textureHeight: 2, arrayLength: 1) + + let filterSize: (width: Int, height: Int, channel: Int) = (3, 3, 4) + let paddings: (Int, Int) = (1, 1) + let stride: (Int, Int) = (2, 2) + + let offsetX = filterSize.width/2 - paddings.0 + let offsetY = filterSize.height/2 - paddings.1 + + let metalParam = MetalConvParam.init(offsetX: Int16(offsetX), offsetY: Int16(offsetY), offsetZ: 0, strideX: UInt16(stride.0), strideY: UInt16(stride.1), paddedZ: UInt16(paddings.0)) + + let param = ConvAddBatchNormReluTestParam.init(inInputTexture: inputeTexture, inOutputTexture: outputTexture, inMetalParam: metalParam, inFilterBuffer: filterBuffer, inBiaseBuffer: biaseBuffer, inNewScaleBuffer: newScalueBuffer, inNewBiaseBuffer: newBiaseBuffer, inFilterSize: filterSize) + + + + let convAddBnReluKernel = ConvAddBatchNormReluKernel.init(device: device, testParam: param) + + convAddBnReluKernel.test(commandBuffer: buffer, param: param) + + buffer.addCompletedHandler { (buffer) in + let _: Float32? = inputeTexture.logDesc(header: "input texture", stridable: false) + let _: Float32? = outputTexture.logDesc(header: "output texture", stridable: false) + } + + buffer.commit() + + +// let inputTexture = device.makeFloatTexture(value: <#T##[P]#>, textureWidth: <#T##Int#>, textureHeight: <#T##Int#>, arrayLength: <#T##Int#>) + + +// let param = ConvAddBatchNormReluTestParam.init(inInputTexture: <#T##MTLTexture#>, inOutputTexture: <#T##MTLTexture#>, inMetalParam: <#T##MetalConvParam#>, inFilterBuffer: <#T##MTLBuffer#>, inBiaseBuffer: <#T##MTLBuffer#>, inNewScaleBuffer: <#T##MTLBuffer#>, inNewBiaseBuffer: <#T##MTLBuffer#>, inFilterSize: <#T##(width: Int, height: Int, channel: Int)#>) + +// ConvAddBatchNormReluKernel.init(device: <#T##MTLDevice#>, testParam: <#T##ConvAddBatchNormReluTestParam#>) + + + } +} + + + diff --git a/metal/paddle-mobile/paddle-mobile/Common/Tools.swift b/metal/paddle-mobile/paddle-mobile/Common/Tools.swift new file mode 100644 index 0000000000..cc1f7a4f21 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Common/Tools.swift @@ -0,0 +1,21 @@ +// +// Tools.swift +// paddle-mobile +// +// Created by liuRuiLong on 2018/7/26. +// Copyright © 2018年 orange. All rights reserved. +// + +import Foundation + + +func writeToLibrary(fileName: String, array: [P]) { + let libraryPath = NSSearchPathForDirectoriesInDomains(.libraryDirectory, .userDomainMask, true).last ?! " library path get error " + let filePath = libraryPath + "/" + fileName + let fileManager = FileManager.init() + fileManager.createFile(atPath: filePath, contents: nil, attributes: nil) + let fileHandler = FileHandle.init(forWritingAtPath: filePath) ?! " file handler nil " + let data = Data.init(buffer: UnsafeBufferPointer.init(start: array, count: array.count)) + fileHandler.write(data) + fileHandler.closeFile() +} diff --git a/metal/paddle-mobile/paddle-mobile/Executor.swift b/metal/paddle-mobile/paddle-mobile/Executor.swift index 00084a63cb..e883754b2c 100644 --- a/metal/paddle-mobile/paddle-mobile/Executor.swift +++ b/metal/paddle-mobile/paddle-mobile/Executor.swift @@ -17,6 +17,7 @@ import Foundation public class ResultHolder { public let dim: [Int] public let resultArr: [P] + public init(inDim: [Int], inResult: [P]) { dim = inDim resultArr = inResult @@ -56,7 +57,7 @@ public class Executor { queue = inQueue for block in inProgram.programDesc.blocks { //block.ops.count - for i in 0...shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope) @@ -79,12 +80,11 @@ public class Executor { } } - public func predict(input: MTLTexture, expect: [Int], preProcessKernle: CusomKernel? = nil) throws -> ResultHolder

{ + public func predict(input: MTLTexture, expect: [Int], completionHandle: @escaping (ResultHolder

) -> Void, preProcessKernle: CusomKernel? = nil) throws { guard let buffer = queue.makeCommandBuffer() else { throw PaddleMobileError.predictError(message: "CommandBuffer is nil") } let resInput: MTLTexture - if let inPre = preProcessKernle { do { try inPre.compute(inputTexuture: input, commandBuffer: buffer) @@ -109,26 +109,36 @@ public class Executor { } buffer.addCompletedHandler { (commandbuffer) in + let inputArr = resInput.floatArray(res: { (p:P) -> P in + return p + }) +// print(inputArr) + +// let stridableInput: [(index: Int, value: Float)] = input.stridableFloatArray() +// print(stridableInput) + +// let _: Flo? = input.logDesc(header: "input: ", stridable: true) for op in self.ops { op.delogOutput() } + return + guard let outputVar = self.program.scope.output() else { + fatalError("output nil") + } + + guard let output = outputVar as? Texture

else { + fatalError("output var type error") + } + let resultHodlder = ResultHolder

.init(inDim: output.dim.dims, inResult: output.metalTexture.floatArray(res: { (p:P) -> P in + return p + })) + completionHandle(resultHodlder) let afterDate = Date.init() print(" encoder end ! time: \(afterDate.timeIntervalSince(beforeDate))") } buffer.commit() - - guard let outputVar = program.scope.output() else { - throw PaddleMobileError.netError(message: "output nil") - } - - guard let output = outputVar as? ResultHolder

else { - throw PaddleMobileError.netError(message: "output var type error") - } - - return output } - } //public let paddle_executor: Executor = Executor.init() diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift b/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift index 3ca41ed724..bc95f84d8a 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift @@ -65,6 +65,7 @@ protocol OperatorProtocol { associatedtype ParamType associatedtype KerType: Computable where Self.KerType.ParamType == ParamType var type: String { get } + var scope: Scope { get } var inputs: [String : [String]] { get } var paraInputs: [String : [String]] { get set } var outpus: [String : [String]] { get } @@ -93,9 +94,11 @@ class Operator : OperatorProtocol where let outpus: [String : [String]] let attrs: [String : Attr] let para: ParamType + let scope: Scope var kernel: KerType required init(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws { type = opDesc.type + scope = inScope inputs = opDesc.inputs outpus = opDesc.outputs attrs = opDesc.attrs diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ConvAddBatchNormReluOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ConvAddBatchNormReluOp.swift index 0968888c8b..8746ba980d 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ConvAddBatchNormReluOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ConvAddBatchNormReluOp.swift @@ -112,12 +112,13 @@ class ConvAddBatchNormReluOp: Operator: Operator, FeedParam< func delogOutput() { // para.input.mtlTexture.logDesc() // let _: P? = para.input.mtlTexture.logDesc(header: "feed input: ", stridable: true) -// let _: P? = para.output.metalTexture.logDesc(header: "feed output: ", stridable: true) +// let _: P? = para.output.metalTexture.logDesc(header: "feed output: ", stridable: false) } } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift index 3bc2bb5cb1..2964b89e5d 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift @@ -15,13 +15,14 @@ import Foundation class FetchParam: OpParam{ - var output: ResultHolder

= ResultHolder.init(inDim: [], inResult: []) + var output: Texture

let input: Texture

let scope: Scope required init(opDesc: OpDesc, inScope: Scope) throws { scope = inScope do { input = try FetchParam.inputX(inputs: opDesc.inputs, from: inScope) + output = input } catch let error { throw error } @@ -47,6 +48,7 @@ class FetchOp: Operator< FetchKernel

, FetchParam

>, Runab typealias OpType = FetchOp

func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { + scope.setOutput(output: para.output) } } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddBatchNormReluKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddBatchNormReluKernel.swift index cac707dfa9..e8ee935390 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddBatchNormReluKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvAddBatchNormReluKernel.swift @@ -14,7 +14,38 @@ import Foundation -class ConvAddBatchNormReluKernel: Kernel, Computable { +struct ConvAddBatchNormReluTestParam: TestParam { + let inputTexture: MTLTexture + let outputTexture: MTLTexture + var metalParam: MetalConvParam + let filterBuffer: MTLBuffer + let biaseBuffer: MTLBuffer + let newScaleBuffer: MTLBuffer + let newBiaseBuffer: MTLBuffer + let filterSize: (width: Int, height: Int, channel: Int) + init(inInputTexture: MTLTexture, inOutputTexture: MTLTexture, inMetalParam: MetalConvParam, inFilterBuffer: MTLBuffer, inBiaseBuffer: MTLBuffer, inNewScaleBuffer: MTLBuffer, inNewBiaseBuffer: MTLBuffer, inFilterSize: (width: Int, height: Int, channel: Int)) { + inputTexture = inInputTexture + outputTexture = inOutputTexture + metalParam = inMetalParam + filterBuffer = inFilterBuffer + biaseBuffer = inBiaseBuffer + newScaleBuffer = inNewScaleBuffer + newBiaseBuffer = inNewBiaseBuffer + filterSize = inFilterSize + } +} + +class ConvAddBatchNormReluKernel: Kernel, Computable, Testable { + required init(device: MTLDevice, testParam: ConvAddBatchNormReluTestParam) { + if testParam.filterSize.width == 1 && testParam.filterSize.height == 1 { + super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_1x1") + } else if testParam.filterSize.channel == 1 { + super.init(device: device, inFunctionName: "depthwise_conv_add_batch_norm_relu_3x3") + } else { + super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_3x3") + } + } + var metalParam: MetalConvParam! required init(device: MTLDevice, param: ConvAddBatchNormReluParam

) { @@ -27,7 +58,6 @@ class ConvAddBatchNormReluKernel: Kernel, Computable { super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_3x3") } - let offsetX = param.filter.width/2 - Int(param.paddings[0]) let offsetY = param.filter.height/2 - Int(param.paddings[1]) @@ -69,7 +99,7 @@ class ConvAddBatchNormReluKernel: Kernel, Computable { guard let encoder = commandBuffer.makeComputeCommandEncoder() else { throw PaddleMobileError.predictError(message: " encode is nil") } - + print("ConvAddBatchNormReluKernel compute") encoder.setTexture(param.input.metalTexture, index: 0) encoder.setTexture(param.output.metalTexture, index: 1) @@ -81,4 +111,22 @@ class ConvAddBatchNormReluKernel: Kernel, Computable { encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) encoder.endEncoding() } + + public func test(commandBuffer: MTLCommandBuffer, param: ConvAddBatchNormReluTestParam) { + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { + fatalError() + } + + print("ConvAddBatchNormReluKernel compute") + encoder.setTexture(param.inputTexture, index: 0) + encoder.setTexture(param.outputTexture, index: 1) + var inMetalParam = param.metalParam + encoder.setBytes(&inMetalParam, length: MemoryLayout.size, index: 0) + encoder.setBuffer(param.filterBuffer, offset: 0, index: 1) + encoder.setBuffer(param.biaseBuffer, offset: 0, index: 2) + encoder.setBuffer(param.newScaleBuffer, offset: 0, index: 3) + encoder.setBuffer(param.newBiaseBuffer, offset: 0, index: 4) + encoder.dispatch(computePipline: pipline, outTexture: param.outputTexture) + encoder.endEncoding() + } } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.metal index 660235eb14..a738d55e39 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.metal +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.metal @@ -81,10 +81,11 @@ kernel void conv_add_batch_norm_relu_3x3(texture2d_array return; } - short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY); + ushort2 stride = ushort2(param.strideX, param.strideY); + const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY); + constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero); const uint kernelHXW = 9; - uint input_arr_size = inTexture.get_array_size(); uint weithTo = gid.z * kernelHXW * input_arr_size * 4; @@ -134,7 +135,9 @@ kernel void conv_add_batch_norm_relu_1x1(texture2d_array return; } - short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY); + ushort2 stride = ushort2(param.strideX, param.strideY); + ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY); + constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero); const uint kernelHXW = 1; @@ -175,7 +178,9 @@ kernel void conv_add_1x1(texture2d_array inTexture [[text return; } - short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY); + ushort2 stride = ushort2(param.strideX, param.strideY); + ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY); + constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero); const uint kernelHXW = 1; @@ -219,7 +224,9 @@ kernel void depthwise_conv_add_batch_norm_relu_3x3(texture2d_array inTexture [[texture(0)]], } kernel void relu(texture2d_array inTexture [[texture(0)]], - texture2d_array outTexture [[texture(1)]], - uint3 gid [[thread_position_in_grid]]) { + texture2d_array outTexture [[texture(1)]], + uint3 gid [[thread_position_in_grid]]) { if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() || gid.z >= outTexture.get_array_size()) return; @@ -119,7 +119,7 @@ kernel void pool(texture2d_array inTexture [[texture(0)]], int ymin = gid.y * pm.strideX - pm.paddingX; int ymax = min(ymin + pm.ksizeX, int(inTexture.get_height())); ymin = max(ymin, 0); - + float4 r = 0; if (pm.poolType == 0) { r = inTexture.read(uint2(xmin, ymin), gid.z); @@ -136,11 +136,6 @@ kernel void pool(texture2d_array inTexture [[texture(0)]], } r /= pm.ksizeX * pm.ksizeY; } -// float4 r; -// r[0] = 1.0 * pm.ksizeX; -// r[1] = 2.0; -// r[2] = 3.0; -// r[3] = 4.0; outTexture.write(r, gid.xy, gid.z); } @@ -151,7 +146,7 @@ kernel void reshape(texture2d_array inTexture [[texture(0)] if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() || gid.z >= outTexture.get_array_size()) return; - + float4 r = inTexture.read(uint2(0, 0), gid.z); outTexture.write(r, gid.xy, gid.z); } diff --git a/metal/paddle-mobile/paddle-mobile/framework/Dim.swift b/metal/paddle-mobile/paddle-mobile/framework/Dim.swift index 1a0e5b2536..672484cd9d 100644 --- a/metal/paddle-mobile/paddle-mobile/framework/Dim.swift +++ b/metal/paddle-mobile/paddle-mobile/framework/Dim.swift @@ -39,7 +39,8 @@ public struct Dim { return dims[index]; } - private var dims: [Int] + + private(set) var dims: [Int] private init(){ fatalError() } diff --git a/metal/paddle-mobile/paddle-mobile/framework/Texture.swift b/metal/paddle-mobile/paddle-mobile/framework/Texture.swift index eae373a141..50f9f7d067 100644 --- a/metal/paddle-mobile/paddle-mobile/framework/Texture.swift +++ b/metal/paddle-mobile/paddle-mobile/framework/Texture.swift @@ -22,6 +22,7 @@ class InputTexture { mtlTexture = inMTLTexture expectDim = inExpectDim } + } extension InputTexture { -- GitLab