From fb7c6e264484b2430266de5bdfedc6078dd68ffa Mon Sep 17 00:00:00 2001 From: dolphin8 Date: Wed, 12 Sep 2018 22:11:00 +0800 Subject: [PATCH] batchnorm --- .../paddle-mobile.xcodeproj/project.pbxproj | 20 +- .../paddle-mobile/Operators/BatchNormOp.swift | 19 +- .../paddle-mobile/Operators/FlattenOp.swift | 19 +- .../Operators/Kernels/BatchNormKernel.swift | 55 ++--- .../Operators/Kernels/ConcatKernel.swift | 5 +- .../Operators/Kernels/FlattenKernel.swift | 71 ++++++ .../Operators/Kernels/ReshapeKernel.swift | 6 +- .../Operators/Kernels/SplitKernel.swift | 5 +- .../Kernels/metal/BatchNormKernel.metal | 26 +- ...ernel.metal.inc => ConcatKernel.inc.metal} | 23 +- .../Kernels/metal/ConcatKernel.metal | 48 ++-- ...rnel.metal.inc => ReshapeKernel.inc.metal} | 19 +- .../Kernels/metal/ReshapeKernel.metal | 227 +++++++++--------- .../paddle-mobile/Operators/ReshapeOp.swift | 3 - .../paddle-mobile/Operators/ShapeOp.swift | 8 +- .../paddle-mobile/Operators/SplitOp.swift | 23 +- 16 files changed, 334 insertions(+), 243 deletions(-) create mode 100644 metal/paddle-mobile/paddle-mobile/Operators/Kernels/FlattenKernel.swift rename metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/{ConcatKernel.metal.inc => ConcatKernel.inc.metal} (87%) rename metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/{ReshapeKernel.metal.inc => ReshapeKernel.inc.metal} (82%) diff --git a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj index 56b8e363e9..bf6b480580 100644 --- a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj @@ -16,8 +16,9 @@ 4AA1EA92214665D700D0F791 /* ShapeOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA91214665D700D0F791 /* ShapeOp.swift */; }; 4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA932146661500D0F791 /* ShapeKernel.swift */; }; 
4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA972146666500D0F791 /* FlattenOp.swift */; }; - 4AA1EA9E2148D6F900D0F791 /* ConcatKernel.metal.inc in Headers */ = {isa = PBXBuildFile; fileRef = 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */; }; - 4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.metal.inc in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.metal.inc */; }; + 4AA1EA9E2148D6F900D0F791 /* ConcatKernel.inc.metal in Headers */ = {isa = PBXBuildFile; fileRef = 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */; }; + 4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */; }; + 4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */; }; 4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; }; 4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; }; 4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; }; @@ -126,8 +127,9 @@ 4AA1EA91214665D700D0F791 /* ShapeOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeOp.swift; sourceTree = ""; }; 4AA1EA932146661500D0F791 /* ShapeKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeKernel.swift; sourceTree = ""; }; 4AA1EA972146666500D0F791 /* FlattenOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenOp.swift; sourceTree = ""; }; - 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */ = {isa = PBXFileReference; explicitFileType = 
sourcecode.metal; fileEncoding = 4; path = ConcatKernel.metal.inc; sourceTree = ""; }; - 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.metal.inc */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ReshapeKernel.metal.inc; sourceTree = ""; }; + 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ConcatKernel.inc.metal; sourceTree = ""; }; + 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ReshapeKernel.inc.metal; sourceTree = ""; }; + 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenKernel.swift; sourceTree = ""; }; 4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = ""; }; 4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = ""; }; 4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = ""; }; @@ -395,6 +397,7 @@ FCD04E6720F315020007374F /* PoolKernel.swift */, FCD04E6B20F31A280007374F /* SoftmaxKernel.swift */, FCD04E6F20F31B720007374F /* ReshapeKernel.swift */, + 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */, FCD04E7320F3437E0007374F /* ConvAddKernel.swift */, FCBCCC5A2122F66F00D94F7E /* ConvBNReluKernel.swift */, FCBCCC602122FBDF00D94F7E /* PriorBoxKernel.swift */, @@ -442,7 +445,7 @@ children = ( FC27990D21341016000B6BAD /* BoxCoder.metal */, 4AF928812135673D005B6C3A /* ConcatKernel.metal */, - 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */, + 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */, 
4AF9288321357BE3005B6C3A /* Elementwise.metal */, FC1B16B220EC9A4F00678B91 /* Kernels.metal */, FC4CB74820F0B954007C0C6D /* ConvKernel.metal */, @@ -455,7 +458,7 @@ FCDDC6CB212FDFDB00E5EF74 /* ReluKernel.metal */, FCDDC6CE212FE14700E5EF74 /* PriorBoxKernel.metal */, FCA3A1622132A4AC00084FE5 /* ReshapeKernel.metal */, - 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.metal.inc */, + 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */, FCA3A1642132A5EB00084FE5 /* Common.metal */, FCA67B1621364EF000BD58AA /* ConvTransposeKernel.metal */, FCA67CD42138272900BD58AA /* ConvAddMetal.metal */, @@ -477,7 +480,7 @@ FC4FD9792140E4980073E130 /* PaddleMobile.h in Headers */, FC292C85214257CB00CF622F /* CPUCompute.h in Headers */, FC292C5421421B2F00CF622F /* PaddleMobileGPU.h in Headers */, - 4AA1EA9E2148D6F900D0F791 /* ConcatKernel.metal.inc in Headers */, + 4AA1EA9E2148D6F900D0F791 /* ConcatKernel.inc.metal in Headers */, FC039B6F20E11C3C0081E9F8 /* paddle_mobile.h in Headers */, ); runOnlyForDeploymentPostprocessing = 0; @@ -617,6 +620,7 @@ FCBCCC592122F42700D94F7E /* ConvBNReluOp.swift in Sources */, FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */, FC9D038420E23B01000F735A /* Texture.swift in Sources */, + 4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */, 4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */, FCBCCC652122FCD700D94F7E /* TransposeOp.swift in Sources */, FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */, @@ -657,7 +661,7 @@ FCBCCC67212306B000D94F7E /* ConcatOp.swift in Sources */, FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */, FCEB684A212F00DB00D2448E /* PreluKernel.metal in Sources */, - 4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.metal.inc in Sources */, + 4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.inc.metal in Sources */, FC9A19E32148C31300CD9CBF /* MobilenetSSD_AR.swift in Sources */, FCDDC6CF212FE14700E5EF74 /* PriorBoxKernel.metal in Sources */, FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */, diff 
--git a/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift index cd7bebaf40..38563c51dd 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift @@ -19,11 +19,14 @@ class BatchNormParam: OpParam { required init(opDesc: OpDesc, inScope: Scope) throws { do { input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: inScope) + if input.transpose != [0, 2, 3, 1] { + fatalError("batch norm only accepts NHWC") + } output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: inScope) - inputBias = try BatchNormParam.inputBiase(inputs: opDesc.paraInputs, from: inScope) - inputMean = try BatchNormParam.inputMean(inputs: opDesc.paraInputs, from: inScope) - inputScale = try BatchNormParam.inputScale(inputs: opDesc.paraInputs, from: inScope) - inputVariance = try BatchNormParam.inputVariance(inputs: opDesc.paraInputs, from: inScope) + bias = try BatchNormParam.getFirstTensor(key: "Bias", map: opDesc.paraInputs, from: inScope) + mean = try BatchNormParam.getFirstTensor(key: "Mean", map: opDesc.paraInputs, from: inScope) + scale = try BatchNormParam.getFirstTensor(key: "Scale", map: opDesc.paraInputs, from: inScope) + variance = try BatchNormParam.getFirstTensor(key: "Variance", map: opDesc.paraInputs, from: inScope) epsilon = try BatchNormParam.getAttr(key: "epsilon", attrs: opDesc.attrs) momentum = try BatchNormParam.getAttr(key: "momentum", attrs: opDesc.attrs) } catch let error { @@ -32,10 +35,10 @@ class BatchNormParam: OpParam { } let input: Texture

var output: Texture

- let inputBias: Tensor - let inputMean: Tensor - let inputScale: Tensor - let inputVariance: Tensor + let bias: Tensor

+ let mean: Tensor

+ let scale: Tensor

+ let variance: Tensor

let epsilon: Float let momentum: Float } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/FlattenOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/FlattenOp.swift index 70dd1c0fc8..2abb3a11a2 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/FlattenOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/FlattenOp.swift @@ -14,7 +14,24 @@ import Foundation -class FlattenOp: Operator, ReshapeParam

>, Runable, Creator, InferShaperable{ +class FlattenParam: OpParam { + typealias ParamPrecisionType = P + required init(opDesc: OpDesc, inScope: Scope) throws { + do { + input = try FlattenParam.inputX(inputs: opDesc.inputs, from: inScope) + output = try FlattenParam.outputOut(outputs: opDesc.outputs, from: inScope) + axis = try FlattenParam.getAttr(key: "axis", attrs: opDesc.attrs) + } catch let error { + throw error + } + } + let input: Texture

+ var output: Texture

+ let axis: Int +} + + +class FlattenOp: Operator, FlattenParam

>, Runable, Creator, InferShaperable{ typealias OpType = FlattenOp

diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift index b80a47516e..caa56ba256 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift @@ -15,20 +15,20 @@ import Foundation class BatchNormKernel: Kernel, Computable { -// var newScale: MTLBuffer -// var newBias: MTLBuffer -// required init(device: MTLDevice, param: BatchNormParam

) { -// guard let newScale = device.makeBuffer(length: param.inputScale.buffer.length) else { -// fatalError() -// } -// -// guard let newBias = device.makeBuffer(length: param.inputBias.buffer.length) else { -// fatalError() -// } -// self.newScale = newScale -// self.newBias = newBias -// + let count = param.variance.dim.numel() + let varianceP = param.variance.data.pointer + let meanP = param.mean.data.pointer + let scaleP = param.scale.data.pointer + let biasP = param.scale.data.pointer + for i in 0..: Kernel, Computable { } else { fatalError() } -// -// let varianceBuffer : MTLBuffer = param.inputVariance.buffer -// -// var invStd: [Float32] = Array(repeating: 0, count: varianceBuffer.length) -// let varianceContents = varianceBuffer.contents().assumingMemoryBound(to: P.self) -// for i in 0..<(varianceBuffer.length / MemoryLayout

.stride) { -// invStd[i] = 1 / (Float32(varianceContents[i]) + param.epsilon).squareRoot() -// } -// -// let newScaleContents = newScale.contents().assumingMemoryBound(to: P.self) -// let newBiasContents = newBias.contents().assumingMemoryBound(to: P.self) -// let scale : MTLBuffer = param.inputScale.buffer -// let scaleContents = scale.contents().assumingMemoryBound(to: P.self) -// let bias : MTLBuffer = param.inputBias.buffer -// let biasContents = bias.contents().assumingMemoryBound(to: P.self) -// let meanContents = param.inputMean.buffer.contents().assumingMemoryBound(to: P.self) -// -// for i in 0..<(newScale.length / MemoryLayout

.stride) { -// newScaleContents[i] = P(invStd[i] * Float32(scaleContents[i])) -// newBiasContents[i] = P(Float32(biasContents[i]) - Float32(meanContents[i]) * invStd[i] * Float32(scaleContents[i])) -// } } func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam

) throws { guard let encoder = commandBuffer.makeComputeCommandEncoder() else { throw PaddleMobileError.predictError(message: " encoder is nil") } -// encoder.setTexture(param.input.metalTexture, index: 0) -// encoder.setTexture(param.output.metalTexture, index: 1) -// encoder.setBuffer(newScale, offset: 0, index: 0) -// encoder.setBuffer(newBias, offset: 0, index: 1) + encoder.setTexture(param.input.metalTexture, index: 0) + encoder.setTexture(param.output.metalTexture, index: 1) + encoder.setBuffer(param.scale.buffer, offset: 0, index: 0) + encoder.setBuffer(param.bias.buffer, offset: 0, index: 1) encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) encoder.endEncoding() } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConcatKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConcatKernel.swift index 644476ad9d..7d1a54a549 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConcatKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConcatKernel.swift @@ -122,10 +122,11 @@ class ConcatKernel: Kernel, Computable{ required init(device: MTLDevice, param: ConcatParam

) { param.output.initTexture(device: device, inTranspose: param.transpose, computePrecision: computePrecision) + let orank = param.output.tensorDim.cout() if computePrecision == .Float32 { - super.init(device: device, inFunctionName: "concat") + super.init(device: device, inFunctionName: "concat_\(orank)_float") } else if computePrecision == .Float16 { - super.init(device: device, inFunctionName: "concat_half") + super.init(device: device, inFunctionName: "concat_\(orank)_half") } else { fatalError() } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/FlattenKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/FlattenKernel.swift new file mode 100644 index 0000000000..090c55b161 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/FlattenKernel.swift @@ -0,0 +1,71 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +import Foundation + +struct FlattenMetalParam { + var idim: (Int32, Int32, Int32, Int32) + var itrans: (Int32, Int32, Int32, Int32) + var odim: (Int32, Int32, Int32, Int32) + var otrans: (Int32, Int32, Int32, Int32) +} + + +class FlattenKernel: Kernel, Computable{ + + var metalParam: FlattenMetalParam + + required init(device: MTLDevice, param: FlattenParam

) { + param.output.initTexture(device: device, computePrecision: computePrecision) + var id: [Int32] = [1, 1, 1, 1] + for i in 0..) throws { + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { + throw PaddleMobileError.predictError(message: " encoder is nil") + } + + encoder.setTexture(param.input.metalTexture, index: 0) + encoder.setTexture(param.output.metalTexture, index: 1) + + encoder.setBytes(&metalParam, length: MemoryLayout.size, index: 0) + encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) + encoder.endEncoding() + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReshapeKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReshapeKernel.swift index 91708ff708..b7cf34eac4 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReshapeKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReshapeKernel.swift @@ -49,10 +49,12 @@ class ReshapeKernel: Kernel, Computable{ odim: (od[0], od[1], od[2], od[3]), otrans: (ot[0], ot[1], ot[2], ot[3]) ) + let irank = param.input.tensorDim.cout() + let orank = param.output.tensorDim.cout() if computePrecision == .Float32 { - super.init(device: device, inFunctionName: "reshape") + super.init(device: device, inFunctionName: "reshape_\(irank)_\(orank)_float") } else if computePrecision == .Float16 { - super.init(device: device, inFunctionName: "reshape_half") + super.init(device: device, inFunctionName: "reshape_\(irank)_\(orank)_half") } else { fatalError() } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift index a25a700640..9912c5aee1 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift @@ -27,7 +27,10 @@ class SplitKernel: Kernel, Computable{ } required init(device: MTLDevice, param: SplitParam

) { - param.output.initTexture(device: device, computePrecision: computePrecision) + // param.output.initTexture(device: device, computePrecision: computePrecision) + for output in param.outputList { + output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision) + } if computePrecision == .Float32 { super.init(device: device, inFunctionName: "split") } else if computePrecision == .Float16 { diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BatchNormKernel.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BatchNormKernel.metal index 2311836eef..657187211e 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BatchNormKernel.metal +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BatchNormKernel.metal @@ -15,28 +15,28 @@ #include using namespace metal; -kernel void batchnorm_half(texture2d_array inTexture [[texture(0)]], - texture2d_array outTexture [[texture(1)]], - const device half4 * newScale [[buffer(0)]], - const device half4 * newBias [[buffer(1)]], +kernel void batchnorm(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + const device float4 * newScale [[buffer(0)]], + const device float4 * newBias [[buffer(1)]], uint3 gid [[thread_position_in_grid]]) { if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() || gid.z >= outTexture.get_array_size()) return; - const half4 input = inTexture.read(gid.xy, gid.z); - half4 output = input * newScale[gid.z] + newBias[gid.z]; + const float4 input = inTexture.read(gid.xy, gid.z); + float4 output = input * newScale[gid.z] + newBias[gid.z]; outTexture.write(output, gid.xy, gid.z); } -kernel void batchnorm(texture2d_array inTexture [[texture(0)]], - texture2d_array outTexture [[texture(1)]], - const device float4 * newScale [[buffer(0)]], - const device float4 * newBias [[buffer(1)]], - uint3 gid [[thread_position_in_grid]]) { +kernel void 
batchnorm_half(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + const device half4 * newScale [[buffer(0)]], + const device half4 * newBias [[buffer(1)]], + uint3 gid [[thread_position_in_grid]]) { if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() || gid.z >= outTexture.get_array_size()) return; - const float4 input = inTexture.read(gid.xy, gid.z); - float4 output = input * newScale[gid.z] + newBias[gid.z]; + const half4 input = inTexture.read(gid.xy, gid.z); + half4 output = input * newScale[gid.z] + newBias[gid.z]; outTexture.write(output, gid.xy, gid.z); } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal.inc b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.inc.metal similarity index 87% rename from metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal.inc rename to metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.inc.metal index b473ea6c6d..777b252b5b 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal.inc +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.inc.metal @@ -1,20 +1,14 @@ -#ifndef D -#define D 4 -#endif - -#ifndef P -#define P float -#endif +#ifdef P #define CONCAT2(a, b) a ## b #define CONCAT2_(a, b) a ## _ ## b #define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c -#define FUNC(f, d, p) CONCAT3_(f, d, p) +#define FUNC(f, r, p) CONCAT3_(f, r, p) #define VECTOR(p, n) CONCAT2(p, n) -#define FUNC_D(f, d) CONCAT2_(f, d) +#define FUNC_R(f, r) CONCAT2_(f, r) -kernel void FUNC(concat, D, P)(texture2d_array in0 [[texture(0)]], +kernel void FUNC(concat, R, P)(texture2d_array in0 [[texture(0)]], texture2d_array in1 [[texture(1)]], texture2d_array in2 [[texture(2)]], texture2d_array in3 [[texture(3)]], @@ -29,10 +23,10 @@ kernel void FUNC(concat, D, P)(texture2d_array in0 [[texture(0) VECTOR(P, 4) r; for (int i = 0; i < 4; i++) { 
xyzn[3] = i; -#if D == 4 +#if R == 4 xyzn2abcd_4(cp.odim[3], xyzn, abcd); #else - FUNC_D(xyzn2abcd, D)(xyzn, abcd); + FUNC_R(xyzn2abcd, R)(xyzn, abcd); #endif int k = abcd[cp.axis] - cp.offset; int j = 0; @@ -48,10 +42,10 @@ kernel void FUNC(concat, D, P)(texture2d_array in0 [[texture(0) int ta = cp.odim[cp.axis]; abcd[cp.axis] = k; cp.odim[cp.axis] = cp.vdim[j]; -#if D == 4 +#if R == 4 abcd2xyzn_4(cp.odim[3], abcd, oxyzn); #else - FUNC_D(abcd2xyzn, D)(abcd, oxyzn); + FUNC_R(abcd2xyzn, R)(abcd, oxyzn); #endif cp.odim[cp.axis] = ta; switch (j) { @@ -66,3 +60,4 @@ kernel void FUNC(concat, D, P)(texture2d_array in0 [[texture(0) } out.write(r, gid.xy, gid.z); } +#endif diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal index 65a01182d2..8bd41feefc 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal @@ -26,31 +26,31 @@ struct ConcatParam { }; #define P float -#define D 4 -#include "ConcatKernel.metal.inc" -#undef D -#define D 3 -#include "ConcatKernel.metal.inc" -#undef D -#define D 2 -#include "ConcatKernel.metal.inc" -#undef D -#define D 1 -#include "ConcatKernel.metal.inc" -#undef D +#define R 4 +#include "ConcatKernel.inc.metal" +#undef R +#define R 3 +#include "ConcatKernel.inc.metal" +#undef R +#define R 2 +#include "ConcatKernel.inc.metal" +#undef R +#define R 1 +#include "ConcatKernel.inc.metal" +#undef R #undef P #define P half -#define D 4 -#include "ConcatKernel.metal.inc" -#undef D -#define D 3 -#include "ConcatKernel.metal.inc" -#undef D -#define D 2 -#include "ConcatKernel.metal.inc" -#undef D -#define D 1 -#include "ConcatKernel.metal.inc" -#undef D +#define R 4 +#include "ConcatKernel.inc.metal" +#undef R +#define R 3 +#include "ConcatKernel.inc.metal" +#undef R +#define R 2 +#include "ConcatKernel.inc.metal" +#undef R 
+#define R 1 +#include "ConcatKernel.inc.metal" +#undef R #undef P diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal.inc b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.inc.metal similarity index 82% rename from metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal.inc rename to metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.inc.metal index b5e64aa774..3d6c141210 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal.inc +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.inc.metal @@ -1,17 +1,15 @@ -#ifndef P -#define P float -#endif +#ifdef P #define CONCAT2(a, b) a ## b #define CONCAT2_(a, b) a ## _ ## b #define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c #define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d -#define FUNC(f, d1, d2, p) CONCAT4_(f, d1, d2, p) +#define FUNC(f, r1, r2, p) CONCAT4_(f, r1, r2, p) #define VECTOR(p, n) CONCAT2(p, n) -#define FUNC_D(f, d) CONCAT2_(f, d) +#define FUNC_R(f, r) CONCAT2_(f, r) -kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array inTexture [[texture(0)]], +kernel void FUNC(reshape, RIN, ROUT, P)(texture2d_array inTexture [[texture(0)]], texture2d_array outTexture [[texture(1)]], constant ReshapeParam &rp [[buffer(0)]], uint3 gid [[thread_position_in_grid]]) { @@ -27,10 +25,10 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array inTextu VECTOR(P, 4) r; for (int n = 0; n < 4; n++) { oxyzn[3] = n; -#if DOUT == 4 +#if ROUT == 4 xyzn2abcd_4(oC, oxyzn, oabcd); #else - FUNC_D(xyzn2abcd, DOUT)(oxyzn, oabcd); + FUNC_R(xyzn2abcd, ROUT)(oxyzn, oabcd); #endif int tabcd[4]; invtrans(lrp.otrans, oabcd, tabcd); @@ -39,10 +37,10 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array inTextu index2abcd(lrp.idim, index, tabcd); trans(lrp.itrans, tabcd, iabcd); abcd2xyzn(iC, iabcd, ixyzn); -#if DIN == 4 +#if RIN == 4 abcd2xyzn_4(iC, iabcd, ixyzn); 
#else - FUNC_D(abcd2xyzn, DIN)(iabcd, ixyzn); + FUNC_R(abcd2xyzn, RIN)(iabcd, ixyzn); #endif r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]]; } else { @@ -52,3 +50,4 @@ kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array inTextu outTexture.write(r, gid.xy, gid.z); } +#endif diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal index 75337990c3..d2f5815d42 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal @@ -8,7 +8,7 @@ Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ @@ -25,127 +25,126 @@ struct ReshapeParam { }; #define P float -#define DIN 4 -#define DOUT 4 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 3 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 2 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 1 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#undef DIN +#define RIN 4 +#define ROUT 4 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 3 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 2 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 1 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#undef RIN -#define DIN 3 -#define DOUT 4 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 3 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 2 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 1 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#undef DIN +#define RIN 3 +#define ROUT 4 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 3 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 2 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 1 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#undef RIN -#define DIN 2 -#define DOUT 4 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 3 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 2 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 1 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#undef DIN +#define RIN 2 +#define ROUT 4 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 3 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 2 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 1 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#undef RIN -#define DIN 1 -#define DOUT 4 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 3 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 2 -#include "ReshapeKernel.metal.inc" -#undef DOUT 
-#define DOUT 1 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#undef DIN +#define RIN 1 +#define ROUT 4 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 3 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 2 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 1 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#undef RIN #undef P #define P half -#define DIN 4 -#define DOUT 4 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 3 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 2 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 1 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#undef DIN +#define RIN 4 +#define ROUT 4 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 3 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 2 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 1 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#undef RIN -#define DIN 3 -#define DOUT 4 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 3 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 2 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 1 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#undef DIN +#define RIN 3 +#define ROUT 4 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 3 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 2 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 1 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#undef RIN -#define DIN 2 -#define DOUT 4 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 3 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 2 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 1 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#undef DIN - -#define DIN 1 -#define DOUT 4 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 3 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 2 -#include 
"ReshapeKernel.metal.inc" -#undef DOUT -#define DOUT 1 -#include "ReshapeKernel.metal.inc" -#undef DOUT -#undef DIN +#define RIN 2 +#define ROUT 4 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 3 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 2 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 1 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#undef RIN +#define RIN 1 +#define ROUT 4 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 3 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 2 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#define ROUT 1 +#include "ReshapeKernel.inc.metal" +#undef ROUT +#undef RIN #undef P diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift index 1c1da9901d..bd257a65f3 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift @@ -43,15 +43,12 @@ class ReshapeParam: OpParam { } output.padToFourDim = Dim.init(inDim: dim) output.dim = output.padToFourDim - -// inplace = try ReshapeParam.getAttr(key: "inplace", attrs: opDesc.attrs) } catch let error { throw error } } let input: Texture

let shape: [Int32] -// let inplace: Bool var output: Texture

} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ShapeOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ShapeOp.swift index 7af5562040..daebb37ade 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ShapeOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ShapeOp.swift @@ -18,17 +18,19 @@ class ShapeParam: OpParam { typealias ParamPrecisionType = P required init(opDesc: OpDesc, inScope: Scope) throws { do { - output = try ShapeParam.output(outputs: opDesc.outputs, from: inScope) + input = try ShapeParam.input(inputs: opDesc.inputs, from: inScope) + output = try ShapeParam.outputOut(outputs: opDesc.outputs, from: inScope) } catch let error { throw error } } var output: Texture

+ let input: Texture

} -class ShapeOp: Operator, SplitParam

>, Runable, Creator, InferShaperable{ +class ShapeOp: Operator, ShapeParam

>, Runable, Creator, InferShaperable{ - typealias OpType = SplitOp

+ typealias OpType = ShapeOp

func inferShape() { // para.output.dim = para.input.dim diff --git a/metal/paddle-mobile/paddle-mobile/Operators/SplitOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/SplitOp.swift index 5adc47c663..41bf6784f5 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/SplitOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/SplitOp.swift @@ -18,13 +18,32 @@ class SplitParam: OpParam { typealias ParamPrecisionType = P required init(opDesc: OpDesc, inScope: Scope) throws { do { -// output = try SplitParam.output(outputs: opDesc.outputs, from: inScope) - output = try SplitParam.outputOut(outputs: opDesc.outputs, from: inScope) + input = try SplitParam.inputX(inputs: opDesc.inputs, from: inScope) + output = Texture

.init(device: input.metalTexture!.device, inDim: input.dim) + axis = try SplitParam.getAttr(key: "axis", attrs: opDesc.attrs) + sections = try SplitParam.getAttr(key: "sections", attrs: opDesc.attrs) + if axis < 0 { + axis = input.tensorDim.cout() + axis + } + guard let outlist = opDesc.outputs["Out"] else { + fatalError() + } + for out in outlist { + guard let variant = inScope[out], let v = variant as? Texture

else { + fatalError() + } + outputList.append(v) + sections.append(Int32(v.tensorDim.dims[axis])) + } } catch let error { throw error } } + var axis: Int + let input: Texture

var output: Texture

+ var outputList: [Texture

] = [] + var sections: [Int32] = [] } class SplitOp: Operator, SplitParam

>, Runable, Creator, InferShaperable{ -- GitLab