diff --git a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj index 29ea1ec64c724cfe3070a0599714f6f20453b9b5..fe260157cd907f22d0c851f1ea6e6ca02db109ba 100644 --- a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj @@ -7,6 +7,16 @@ objects = { /* Begin PBXBuildFile section */ + 4AA1EA862146625E00D0F791 /* BilinearInterpOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA852146625E00D0F791 /* BilinearInterpOp.swift */; }; + 4AA1EA88214662BD00D0F791 /* BilinearInterpKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA87214662BD00D0F791 /* BilinearInterpKernel.swift */; }; + 4AA1EA8A2146631C00D0F791 /* BilinearInterp.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA892146631C00D0F791 /* BilinearInterp.metal */; }; + 4AA1EA8C2146640900D0F791 /* SplitOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA8B2146640900D0F791 /* SplitOp.swift */; }; + 4AA1EA8E2146647F00D0F791 /* SplitKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA8D2146647F00D0F791 /* SplitKernel.swift */; }; + 4AA1EA90214664CD00D0F791 /* Split.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA8F214664CD00D0F791 /* Split.metal */; }; + 4AA1EA92214665D700D0F791 /* ShapeOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA91214665D700D0F791 /* ShapeOp.swift */; }; + 4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA932146661500D0F791 /* ShapeKernel.swift */; }; + 4AA1EA962146665A00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA952146665A00D0F791 /* FlattenKernel.swift */; }; + 4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA972146666500D0F791 /* FlattenOp.swift */; }; 4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef 
= 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; }; 4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; }; 4AF928822135673D005B6C3A /* Concat.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* Concat.metal */; }; @@ -104,6 +114,16 @@ /* End PBXBuildFile section */ /* Begin PBXFileReference section */ + 4AA1EA852146625E00D0F791 /* BilinearInterpOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BilinearInterpOp.swift; sourceTree = ""; }; + 4AA1EA87214662BD00D0F791 /* BilinearInterpKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BilinearInterpKernel.swift; sourceTree = ""; }; + 4AA1EA892146631C00D0F791 /* BilinearInterp.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BilinearInterp.metal; sourceTree = ""; }; + 4AA1EA8B2146640900D0F791 /* SplitOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SplitOp.swift; sourceTree = ""; }; + 4AA1EA8D2146647F00D0F791 /* SplitKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SplitKernel.swift; sourceTree = ""; }; + 4AA1EA8F214664CD00D0F791 /* Split.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Split.metal; sourceTree = ""; }; + 4AA1EA91214665D700D0F791 /* ShapeOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeOp.swift; sourceTree = ""; }; + 4AA1EA932146661500D0F791 /* ShapeKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeKernel.swift; sourceTree = ""; }; + 4AA1EA952146665A00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; 
path = FlattenKernel.swift; sourceTree = ""; }; + 4AA1EA972146666500D0F791 /* FlattenOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenOp.swift; sourceTree = ""; }; 4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = ""; }; 4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = ""; }; 4AF928812135673D005B6C3A /* Concat.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Concat.metal; sourceTree = ""; }; @@ -324,6 +344,10 @@ FCBCCC642122FCD700D94F7E /* TransposeOp.swift */, FCBCCC66212306B000D94F7E /* ConcatOp.swift */, FCBCCC6A2123071700D94F7E /* BoxcoderOp.swift */, + 4AA1EA8B2146640900D0F791 /* SplitOp.swift */, + 4AA1EA91214665D700D0F791 /* ShapeOp.swift */, + 4AA1EA972146666500D0F791 /* FlattenOp.swift */, + 4AA1EA852146625E00D0F791 /* BilinearInterpOp.swift */, FCBCCC6E2123097100D94F7E /* MulticlassNMSOp.swift */, FCDE8A32212A917900F4A8F6 /* ConvTransposeOp.swift */, FCEB684B212F093800D2448E /* PreluOp.swift */, @@ -369,6 +393,10 @@ FCBCCC622122FCC000D94F7E /* TransposeKernel.swift */, FCBCCC68212306D300D94F7E /* ConcatKernel.swift */, FCBCCC6C2123073A00D94F7E /* BoxcoderKernel.swift */, + 4AA1EA8D2146647F00D0F791 /* SplitKernel.swift */, + 4AA1EA932146661500D0F791 /* ShapeKernel.swift */, + 4AA1EA952146665A00D0F791 /* FlattenKernel.swift */, + 4AA1EA87214662BD00D0F791 /* BilinearInterpKernel.swift */, FCBCCC70212309A700D94F7E /* MulticlassNMSKernel.swift */, FCDDC6C5212F9FB800E5EF74 /* PreluKernel.swift */, ); @@ -411,6 +439,8 @@ FC1B16B220EC9A4F00678B91 /* Kernels.metal */, FC4CB74820F0B954007C0C6D /* ConvKernel.metal */, 4AF928762133F1DB005B6C3A /* BoxCoder.metal */, + 4AA1EA8F214664CD00D0F791 /* Split.metal */, + 
4AA1EA892146631C00D0F791 /* BilinearInterp.metal */, 4AF9287821341661005B6C3A /* Softmax.metal */, FCEB6849212F00DB00D2448E /* PreluKernel.metal */, FCDDC6C9212FDF6800E5EF74 /* BatchNormKernel.metal */, @@ -536,6 +566,7 @@ FC039B9F20E11CB20081E9F8 /* Tensor.swift in Sources */, FCA67CD7213827AC00BD58AA /* ConvAddBNReluKernel.metal in Sources */, 4AF9287921341661005B6C3A /* Softmax.metal in Sources */, + 4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */, FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */, FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */, FCDE8A33212A917900F4A8F6 /* ConvTransposeOp.swift in Sources */, @@ -553,6 +584,7 @@ FCDDC6C6212F9FB800E5EF74 /* PreluKernel.swift in Sources */, FCA67CD52138272900BD58AA /* ConvAddMetal.metal in Sources */, FCBCCC5B2122F66F00D94F7E /* ConvBNReluKernel.swift in Sources */, + 4AA1EA8C2146640900D0F791 /* SplitOp.swift in Sources */, FC292C81214255BD00CF622F /* CPUCompute.mm in Sources */, FCEBC0F420F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift in Sources */, FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */, @@ -561,6 +593,7 @@ FCA67CD92138287B00BD58AA /* ConvBNReluKernel.metal in Sources */, FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */, FCEBC0F620F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift in Sources */, + 4AA1EA8A2146631C00D0F791 /* BilinearInterp.metal in Sources */, FCDDC6CA212FDF6800E5EF74 /* BatchNormKernel.metal in Sources */, FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */, FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */, @@ -573,13 +606,16 @@ FCBCCC592122F42700D94F7E /* ConvBNReluOp.swift in Sources */, FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */, FC9D038420E23B01000F735A /* Texture.swift in Sources */, + 4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */, FCBCCC652122FCD700D94F7E /* TransposeOp.swift in Sources */, FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */, 
FC039B9820E11C9A0081E9F8 /* Errors.swift in Sources */, FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */, + 4AA1EA8E2146647F00D0F791 /* SplitKernel.swift in Sources */, FCD04E7420F3437E0007374F /* ConvAddKernel.swift in Sources */, FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */, FC292C5621421B4600CF622F /* PaddleMobileGPU.m in Sources */, + 4AA1EA962146665A00D0F791 /* FlattenKernel.swift in Sources */, FCD04E6620F314C50007374F /* PoolOp.swift in Sources */, FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */, FCBCCC6F2123097100D94F7E /* MulticlassNMSOp.swift in Sources */, @@ -590,11 +626,13 @@ FCBCCC71212309A700D94F7E /* MulticlassNMSKernel.swift in Sources */, FCDC0FEB21099A1D00DC9EFB /* Tools.swift in Sources */, FC0E2DBA20EE3B8D009C1FAC /* ReluKernel.swift in Sources */, + 4AA1EA862146625E00D0F791 /* BilinearInterpOp.swift in Sources */, FCBCCC6D2123073A00D94F7E /* BoxcoderKernel.swift in Sources */, FCBCCC69212306D300D94F7E /* ConcatKernel.swift in Sources */, FCDDC6C8212FA3CA00E5EF74 /* ConvTransposeKernel.swift in Sources */, FC82735920E3C04200BE430A /* OpCreator.swift in Sources */, FCA3A1652132A5EB00084FE5 /* Common.metal in Sources */, + 4AA1EA92214665D700D0F791 /* ShapeOp.swift in Sources */, FCBCCC5D2122F8A100D94F7E /* DepthwiseConvOp.swift in Sources */, FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */, FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */, @@ -613,11 +651,13 @@ FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */, FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */, FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */, + 4AA1EA90214664CD00D0F791 /* Split.metal in Sources */, FCD04E6820F315020007374F /* PoolKernel.swift in Sources */, FC0226582138F38D00F395E2 /* PoolKernel.metal in Sources */, FC039BAD20E11CBC0081E9F8 /* ReluOp.swift in Sources */, FCBCCC572122F41300D94F7E /* DwConvBNReluOp.swift in Sources */, FC039BBE20E11CC20081E9F8 /* OpDesc.swift in 
Sources */, + 4AA1EA88214662BD00D0F791 /* BilinearInterpKernel.swift in Sources */, FC039B9720E11C9A0081E9F8 /* Extensions.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift b/metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift index e7e542e9b70e6a6f35e8aac60b7a165d55dc9139..68763feef8e347cdfa3b7be5096aadc67fb93084 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift @@ -61,7 +61,10 @@ class OpCreator { gPriorBoxType : PriorBoxOp

.creat, gPreluType : PreluOp

.creat, gConv2dTransposeType : ConvTransposeOp

.creat, - gResizeBilinearType : ResizeBilinearOp

.creat] + gBilinearInterpType : BilinearInterpOp

.creat, + gSplit : SplitOp

.creat, + gShape : ShapeOp

.creat, + gFlatten : FlattenOp

.creat] private init(){} } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift b/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift index c3990479e488ad7d170965a4224972b9278e22f1..0bd6b3692d0f983a8b154bec9727468942ef7a51 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift @@ -139,8 +139,10 @@ let gConvBnReluType = "conv_bn_relu" let gDwConvBnReluType = "depth_conv_bn_relu" let gPreluType = "prelu" let gConv2dTransposeType = "conv2d_transpose" -let gResizeBilinearType = "resize_bilinear" - +let gBilinearInterpType = "bilinear_interp" +let gSplit = "split" +let gShape = "shape" +let gFlatten = "flatten" let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]), gBatchNormType : (inputs: ["X"], outputs: ["Y"]), @@ -163,5 +165,8 @@ let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Out gPriorBoxType : (inputs: ["Input", "Image"], outputs: ["Boxes", "Variances"]), gPreluType : (inputs: ["X"], outputs: ["Out"]), gConv2dTransposeType : (inputs: ["Input"], outputs: ["Output"]), - gResizeBilinearType : (inputs: ["X"], outputs: ["Out"]) + gBilinearInterpType : (inputs: ["X"], outputs: ["Out"]), + gSplit : (inputs: ["Input"], outputs: ["Out"]), + gShape : (inputs: ["Input"], outputs: ["Out"]), + gFlatten : (inputs: ["Input"], outputs: ["Out"]) ] diff --git a/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift index 68441244a3915203a8909166d58aee172483364a..cd7bebaf40204affc2009258af5894b7a2cc40ec 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift @@ -26,7 +26,6 @@ class BatchNormParam: OpParam { inputVariance = try BatchNormParam.inputVariance(inputs: opDesc.paraInputs, from: inScope) epsilon = try BatchNormParam.getAttr(key: "epsilon", attrs: opDesc.attrs) momentum 
= try BatchNormParam.getAttr(key: "momentum", attrs: opDesc.attrs) - is_test = try BatchNormParam.getAttr(key: "is_test", attrs: opDesc.attrs) } catch let error { throw error } @@ -39,7 +38,6 @@ class BatchNormParam: OpParam { let inputVariance: Tensor let epsilon: Float let momentum: Float - let is_test: Bool } class BatchNormOp: Operator, BatchNormParam

>, Runable, Creator, InferShaperable{ diff --git a/metal/paddle-mobile/paddle-mobile/Operators/BilinearInterpOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/BilinearInterpOp.swift new file mode 100644 index 0000000000000000000000000000000000000000..e7f0db22312e8d49505513290bd21a6695d65790 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/BilinearInterpOp.swift @@ -0,0 +1,64 @@ +///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. */ + +import Foundation + +class BilinearInterpParam: OpParam { + typealias ParamPrecisionType = P + required init(opDesc: OpDesc, inScope: Scope) throws { + do { + input = try BilinearInterpParam.inputX(inputs: opDesc.inputs, from: inScope) +// if (input.transpose != [0, 2, 3, 1]) || (input.tensorDim.cout() != 4) { +// fatalError() +// } + output = try BilinearInterpParam.outputOut(outputs: opDesc.outputs, from: inScope) + out_h = try BilinearInterpParam.getAttr(key: "out_h", attrs: opDesc.attrs) + out_w = try BilinearInterpParam.getAttr(key: "out_w", attrs: opDesc.attrs) + } catch let error { + throw error + } + } + let input: Texture

+ var output: Texture

+ let out_h: Int + let out_w: Int +} + +class BilinearInterpOp: Operator, BilinearInterpParam

>, Runable, Creator, InferShaperable{ + + typealias OpType = BilinearInterpOp

+ + func inferShape() { + // para.output.dim = para.input.dim + } + + func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { + do { + try kernel.compute(commandBuffer: buffer, param: para) + } catch let error { + throw error + } + } + + func delogOutput() { + print(" \(type) output: ") + } + +} + + + + + + diff --git a/metal/paddle-mobile/paddle-mobile/Operators/FlattenOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/FlattenOp.swift new file mode 100644 index 0000000000000000000000000000000000000000..a7e92bdff7161a41a8c278176e710474898814ef --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/FlattenOp.swift @@ -0,0 +1,55 @@ +///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. */ + +import Foundation + +class FlattenParam: OpParam { + typealias ParamPrecisionType = P + required init(opDesc: OpDesc, inScope: Scope) throws { + do { + output = try FlattenParam.output(outputs: opDesc.outputs, from: inScope) + } catch let error { + throw error + } + } + var output: Texture

+} + +class FlattenOp: Operator, FlattenParam

>, Runable, Creator, InferShaperable{ + + typealias OpType = FlattenOp

+ + func inferShape() { + // para.output.dim = para.input.dim + } + + func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { + do { + try kernel.compute(commandBuffer: buffer, param: para) + } catch let error { + throw error + } + } + + func delogOutput() { + print(" \(type) output: ") + } + +} + + + + + + diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift index 80145e677beb3f37e92257a56dc4e3c7b337e787..b80a47516e083b3f5f303202b0e5f08d6c796a65 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift @@ -15,20 +15,20 @@ import Foundation class BatchNormKernel: Kernel, Computable { - var newScale: MTLBuffer - var newBias: MTLBuffer - +// var newScale: MTLBuffer +// var newBias: MTLBuffer +// required init(device: MTLDevice, param: BatchNormParam

) { - guard let newScale = device.makeBuffer(length: param.inputScale.buffer.length) else { - fatalError() - } - - guard let newBias = device.makeBuffer(length: param.inputBias.buffer.length) else { - fatalError() - } - self.newScale = newScale - self.newBias = newBias - +// guard let newScale = device.makeBuffer(length: param.inputScale.buffer.length) else { +// fatalError() +// } +// +// guard let newBias = device.makeBuffer(length: param.inputBias.buffer.length) else { +// fatalError() +// } +// self.newScale = newScale +// self.newBias = newBias +// if computePrecision == .Float32 { super.init(device: device, inFunctionName: "batchnorm") } else if computePrecision == .Float16 { @@ -36,37 +36,37 @@ class BatchNormKernel: Kernel, Computable { } else { fatalError() } - - let varianceBuffer : MTLBuffer = param.inputVariance.buffer - - var invStd: [Float32] = Array(repeating: 0, count: varianceBuffer.length) - let varianceContents = varianceBuffer.contents().assumingMemoryBound(to: P.self) - for i in 0..<(varianceBuffer.length / MemoryLayout

.stride) { - invStd[i] = 1 / (Float32(varianceContents[i]) + param.epsilon).squareRoot() - } - - let newScaleContents = newScale.contents().assumingMemoryBound(to: P.self) - let newBiasContents = newBias.contents().assumingMemoryBound(to: P.self) - let scale : MTLBuffer = param.inputScale.buffer - let scaleContents = scale.contents().assumingMemoryBound(to: P.self) - let bias : MTLBuffer = param.inputBias.buffer - let biasContents = bias.contents().assumingMemoryBound(to: P.self) - let meanContents = param.inputMean.buffer.contents().assumingMemoryBound(to: P.self) - - for i in 0..<(newScale.length / MemoryLayout

.stride) { - newScaleContents[i] = P(invStd[i] * Float32(scaleContents[i])) - newBiasContents[i] = P(Float32(biasContents[i]) - Float32(meanContents[i]) * invStd[i] * Float32(scaleContents[i])) - } +// +// let varianceBuffer : MTLBuffer = param.inputVariance.buffer +// +// var invStd: [Float32] = Array(repeating: 0, count: varianceBuffer.length) +// let varianceContents = varianceBuffer.contents().assumingMemoryBound(to: P.self) +// for i in 0..<(varianceBuffer.length / MemoryLayout

.stride) { +// invStd[i] = 1 / (Float32(varianceContents[i]) + param.epsilon).squareRoot() +// } +// +// let newScaleContents = newScale.contents().assumingMemoryBound(to: P.self) +// let newBiasContents = newBias.contents().assumingMemoryBound(to: P.self) +// let scale : MTLBuffer = param.inputScale.buffer +// let scaleContents = scale.contents().assumingMemoryBound(to: P.self) +// let bias : MTLBuffer = param.inputBias.buffer +// let biasContents = bias.contents().assumingMemoryBound(to: P.self) +// let meanContents = param.inputMean.buffer.contents().assumingMemoryBound(to: P.self) +// +// for i in 0..<(newScale.length / MemoryLayout

.stride) { +// newScaleContents[i] = P(invStd[i] * Float32(scaleContents[i])) +// newBiasContents[i] = P(Float32(biasContents[i]) - Float32(meanContents[i]) * invStd[i] * Float32(scaleContents[i])) +// } } func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam

) throws { guard let encoder = commandBuffer.makeComputeCommandEncoder() else { throw PaddleMobileError.predictError(message: " encoder is nil") } - encoder.setTexture(param.input.metalTexture, index: 0) - encoder.setTexture(param.output.metalTexture, index: 1) - encoder.setBuffer(newScale, offset: 0, index: 0) - encoder.setBuffer(newBias, offset: 0, index: 1) +// encoder.setTexture(param.input.metalTexture, index: 0) +// encoder.setTexture(param.output.metalTexture, index: 1) +// encoder.setBuffer(newScale, offset: 0, index: 0) +// encoder.setBuffer(newBias, offset: 0, index: 1) encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) encoder.endEncoding() } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BilinearInterpKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BilinearInterpKernel.swift new file mode 100644 index 0000000000000000000000000000000000000000..ab6a44187f75fdee9484026ec859347b6c6166dc --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BilinearInterpKernel.swift @@ -0,0 +1,49 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +import Foundation + +struct BilinearInterpMetalParam { + var ratio_h: Float32 + var ratio_w: Float32 +} + +class BilinearInterpKernel: Kernel, Computable{ + func compute(commandBuffer: MTLCommandBuffer, param: BilinearInterpParam

) throws { + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { + throw PaddleMobileError.predictError(message: " encode is nil") + } + + encoder.setTexture(param.input.metalTexture, index: 0) + encoder.setTexture(param.output.metalTexture, index: 1) + let ratio_h: Float32 = Float32(param.input.tensorDim.dims[2]) / Float32(param.output.tensorDim.dims[2]) + let ratio_w: Float32 = Float32(param.input.tensorDim.dims[3]) / Float32(param.output.tensorDim.dims[3]) + var p = BilinearInterpMetalParam.init(ratio_h: ratio_h, ratio_w: ratio_w) + encoder.setBytes(&p, length: MemoryLayout.size, index: 0) + encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) + encoder.endEncoding() + } + + required init(device: MTLDevice, param: BilinearInterpParam

) { + param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision) + if computePrecision == .Float32 { + super.init(device: device, inFunctionName: "bilinear_interp") + } else if computePrecision == .Float16 { + super.init(device: device, inFunctionName: "bilinear_interp_half") + } else { + fatalError() + } + } + +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/FlattenKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/FlattenKernel.swift new file mode 100644 index 0000000000000000000000000000000000000000..87c317f68ad22995e7981537ef7be29fb9a19cc5 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/FlattenKernel.swift @@ -0,0 +1,40 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +import Foundation + +struct FlattenMetalParam { +} + +class FlattenKernel: Kernel, Computable{ + func compute(commandBuffer: MTLCommandBuffer, param: FlattenParam

) throws { + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { + throw PaddleMobileError.predictError(message: " encode is nil") + } + encoder.setTexture(param.output.metalTexture, index: 0) + encoder.endEncoding() + } + + required init(device: MTLDevice, param: FlattenParam

) { + param.output.initTexture(device: device, computePrecision: computePrecision) + if computePrecision == .Float32 { + super.init(device: device, inFunctionName: "split") + } else if computePrecision == .Float16 { + super.init(device: device, inFunctionName: "split_half") + } else { + fatalError() + } + } + +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ShapeKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ShapeKernel.swift new file mode 100644 index 0000000000000000000000000000000000000000..2efcd45da4b717dbabdb918d95df64d2bc9b174b --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ShapeKernel.swift @@ -0,0 +1,40 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +import Foundation + +struct ShapeMetalParam { +} + +class ShapeKernel: Kernel, Computable{ + func compute(commandBuffer: MTLCommandBuffer, param: ShapeParam

) throws { + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { + throw PaddleMobileError.predictError(message: " encode is nil") + } + encoder.setTexture(param.output.metalTexture, index: 0) + encoder.endEncoding() + } + + required init(device: MTLDevice, param: ShapeParam

) { + param.output.initTexture(device: device, computePrecision: computePrecision) + if computePrecision == .Float32 { + super.init(device: device, inFunctionName: "split") + } else if computePrecision == .Float16 { + super.init(device: device, inFunctionName: "split_half") + } else { + fatalError() + } + } + +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift new file mode 100644 index 0000000000000000000000000000000000000000..a25a70064045a17bb46a22fbbddf824f1d99e51c --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/SplitKernel.swift @@ -0,0 +1,40 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +import Foundation + +struct SplitMetalParam { +} + +class SplitKernel: Kernel, Computable{ + func compute(commandBuffer: MTLCommandBuffer, param: SplitParam

) throws { + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { + throw PaddleMobileError.predictError(message: " encode is nil") + } + encoder.setTexture(param.output.metalTexture, index: 0) + encoder.endEncoding() + } + + required init(device: MTLDevice, param: SplitParam

) { + param.output.initTexture(device: device, computePrecision: computePrecision) + if computePrecision == .Float32 { + super.init(device: device, inFunctionName: "split") + } else if computePrecision == .Float16 { + super.init(device: device, inFunctionName: "split_half") + } else { + fatalError() + } + } + +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BilinearInterp.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BilinearInterp.metal new file mode 100644 index 0000000000000000000000000000000000000000..14b3882e0d18e9bced31263e1f178fd8b9b971f2 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/BilinearInterp.metal @@ -0,0 +1,75 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#include +using namespace metal; + +struct bilinear_interp_param { +// int32_t out_h; +// int32_t out_w; + float ratio_h; + float ratio_w; +}; + +kernel void bilinear_interp(texture2d_array input [[texture(0)]], + texture2d_array output [[texture(1)]], + constant bilinear_interp_param & pm [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + float4 r; + if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) { + r = input.read(gid.xy, gid.z); + } else { + float w = gid.x * pm.ratio_w; + float h = gid.y * pm.ratio_h; + uint w0 = w, h0 = h; + uint w1 = w0 + 1, h1 = h0 + 1; + float w1lambda = w - w0, h1lambda = h - h0; + float w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda; + if (w1 >= input.get_width()) w1 = w0; + if (h1 >= input.get_height()) h1 = h0; + float4 r0 = input.read(uint2(w0, h0), gid.z); + float4 r1 = input.read(uint2(w1, h0), gid.z); + float4 r2 = input.read(uint2(w0, h1), gid.z); + float4 r3 = input.read(uint2(w1, h1), gid.z); + r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3); + } + output.write(r, gid.xy, gid.z); +} + +kernel void bilinear_interp_half(texture2d_array input [[texture(0)]], + texture2d_array output [[texture(1)]], + constant bilinear_interp_param & pm [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + + half4 r; + if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) { + r = input.read(gid.xy, gid.z); + } else { + half w = gid.x * pm.ratio_w; + half h = gid.y * pm.ratio_h; + uint w0 = w, h0 = h; + uint w1 = w0 + 1, h1 = h0 + 1; + half w1lambda = w - w0, h1lambda = h - h0; + half w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda; + if (w1 >= input.get_width()) w1 = w0; + if (h1 >= input.get_height()) h1 = h0; + half4 r0 = input.read(uint2(w0, h0), gid.z); + half4 r1 = input.read(uint2(w1, h0), gid.z); + half4 r2 = input.read(uint2(w0, h1), gid.z); + half4 r3 = input.read(uint2(w1, h1), 
gid.z); + r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3); + } + output.write(r, gid.xy, gid.z); + output.write(r, gid.xy, gid.z); +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ResizeBilinear.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ResizeBilinear.metal index 4adfce0d151ee74baac79638936b443e438e822d..fbb4e12cb82c12f8dc5b94c397e43b8c8c5ae518 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ResizeBilinear.metal +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ResizeBilinear.metal @@ -28,7 +28,7 @@ kernel void resize_bilinear(texture2d_array input [[texture uint3 gid [[thread_position_in_grid]]) { float4 r; if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) { - r = input.read(gid.xy, gid.z) + r = input.read(gid.xy, gid.z); } else { float w = gid.x * pm.ratio_w; float h = gid.y * pm.ratio_h; @@ -42,7 +42,7 @@ kernel void resize_bilinear(texture2d_array input [[texture float4 r1 = input.read(uint2(w1, h0), gid.z); float4 r2 = input.read(uint2(w0, h1), gid.z); float4 r3 = input.read(uint2(w1, h1), gid.z); - r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r3 + w1lambda * r4); + r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3); } output.write(r, gid.xy, gid.z); } @@ -54,7 +54,7 @@ kernel void resize_bilinear_half(texture2d_array input [[tex half4 r; if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) { - r = input.read(gid.xy, gid.z) + r = input.read(gid.xy, gid.z); } else { half w = gid.x * pm.ratio_w; half h = gid.y * pm.ratio_h; @@ -68,7 +68,7 @@ kernel void resize_bilinear_half(texture2d_array input [[tex half4 r1 = input.read(uint2(w1, h0), gid.z); half4 r2 = input.read(uint2(w0, h1), gid.z); half4 r3 = input.read(uint2(w1, h1), gid.z); - r = h2lambda * (w2lambda * r0 + w1lambda * r1) + 
h1lambda * (w2lambda * r3 + w1lambda * r4); + r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3); } output.write(r, gid.xy, gid.z); output.write(r, gid.xy, gid.z); diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Split.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Split.metal new file mode 100644 index 0000000000000000000000000000000000000000..ccdaf47583d88302489f3a9d3c6922d454825b8a --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/Split.metal @@ -0,0 +1,30 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/
+
+#include <metal_stdlib>
+using namespace metal;
+// Placeholder float32 split kernel: no input texture is bound, so each thread clears output texel gid.xy of slice gid.z.
+kernel void split(texture2d_array<float, access::write> output[[texture(0)]],
+                  uint3 gid [[thread_position_in_grid]]) {
+    float4 r = float4(0.0);
+    // Zero-initialised so the store below never writes an indeterminate value.
+    output.write(r, gid.xy, gid.z);
+}
+// Float16 variant of the placeholder: the same zero texel, converted to half4 before the store.
+kernel void split_half(texture2d_array<half, access::write> output[[texture(0)]],
+                       uint3 gid [[thread_position_in_grid]]) {
+    float4 r = float4(0.0);
+    // Zero-initialised so the store below never writes an indeterminate value.
+    output.write(half4(r), gid.xy, gid.z);
+}
diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift index 3fd9ebfb883d43c51b5ede4f4c6d91b8a59cbeda..1c1da9901d5740558d8cfd6363f2e96b15728556 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift @@ -44,14 +44,14 @@ class ReshapeParam: OpParam { output.padToFourDim = Dim.init(inDim: dim) output.dim = output.padToFourDim - inplace = try ReshapeParam.getAttr(key: "inplace", attrs: opDesc.attrs) +// inplace = try ReshapeParam.getAttr(key: "inplace", attrs: opDesc.attrs) } catch let error { throw error } } let input: Texture

let shape: [Int32] - let inplace: Bool +// let inplace: Bool var output: Texture

} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ResizeBilinearOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ResizeBilinearOp.swift index 6f1b361811604b9b4fb538499e036acd67b0d931..e0e699cdb8b3a17eb109877f1a7bd986b5e07403 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ResizeBilinearOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ResizeBilinearOp.swift @@ -19,9 +19,9 @@ class ResizeBilinearParam: OpParam { required init(opDesc: OpDesc, inScope: Scope) throws { do { input = try ResizeBilinearParam.inputX(inputs: opDesc.inputs, from: inScope) - if (input.transpose != [0, 2, 3, 1]) || (input.tensorDim.cout() != 4) { - fatalError() - } +// if (input.transpose != [0, 2, 3, 1]) || (input.tensorDim.cout() != 4) { +// fatalError() +// } output = try ResizeBilinearParam.outputOut(outputs: opDesc.outputs, from: inScope) out_h = try ResizeBilinearParam.getAttr(key: "out_h", attrs: opDesc.attrs) out_w = try ResizeBilinearParam.getAttr(key: "out_w", attrs: opDesc.attrs) diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ShapeOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ShapeOp.swift new file mode 100644 index 0000000000000000000000000000000000000000..7af5562040d86d8c1b0989344650803d6c32975f --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/ShapeOp.swift @@ -0,0 +1,55 @@ +///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
*/
+
+import Foundation
+/// Parameters of the shape operator: only the `output` texture, set once in `init`. NOTE(review): `P` is not declared in this header as shown -- a generic parameter list was presumably lost; confirm.
+class ShapeParam: OpParam {
+    typealias ParamPrecisionType = P
+    required init(opDesc: OpDesc, inScope: Scope) throws {  // the do/catch below only rethrows the lookup error
+        do {
+            output = try ShapeParam.output(outputs: opDesc.outputs, from: inScope)
+        } catch let error {
+            throw error
+        }
+    }
+    var output: Texture
+}
+/// Shape operator. NOTE(review): the generic argument lists of this header look stripped (stray `>`), and it names `SplitParam` rather than `ShapeParam` -- confirm against the repository.
+class ShapeOp: Operator, SplitParam>, Runable, Creator, InferShaperable{
+    // NOTE(review): `OpType` is `SplitOp` inside `ShapeOp`; looks like a copy-paste leftover from SplitOp.swift -- confirm.
+    typealias OpType = SplitOp
+    /// No-op: the only statement is commented out, so no output shape is inferred here.
+    func inferShape() {
+        // para.output.dim = para.input.dim
+    }
+    /// Runs `kernel.compute` on `buffer` with `para` and rethrows any error unchanged; `device` is not used.
+    func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
+        do {
+            try kernel.compute(commandBuffer: buffer, param: para)
+        } catch let error {
+            throw error
+        }
+    }
+    /// Debug hook: prints `type` followed by " output: "; no tensor contents are dumped.
+    func delogOutput() {
+        print(" \(type) output: ")
+    }
+
+}
+
+
+
+
+
+
diff --git a/metal/paddle-mobile/paddle-mobile/Operators/SplitOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/SplitOp.swift
new file mode 100644
index 0000000000000000000000000000000000000000..6ea783b55a206454e304ad0e117237b05a634c4d
--- /dev/null
+++ b/metal/paddle-mobile/paddle-mobile/Operators/SplitOp.swift
@@ -0,0 +1,55 @@
+///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License. */
+
+import Foundation
+/// Parameters of the split operator: only the `output` texture, set once in `init`. NOTE(review): `P` is not declared in this header as shown -- a generic parameter list was presumably lost; confirm.
+class SplitParam: OpParam {
+    typealias ParamPrecisionType = P
+    required init(opDesc: OpDesc, inScope: Scope) throws {  // the do/catch below only rethrows the lookup error
+        do {
+            output = try SplitParam.output(outputs: opDesc.outputs, from: inScope)
+        } catch let error {
+            throw error
+        }
+    }
+    var output: Texture
+}
+/// Split operator. NOTE(review): the generic argument lists of this header look stripped (stray `>`); restore them from the repository before building.
+class SplitOp: Operator, SplitParam>, Runable, Creator, InferShaperable{
+    // Operator type exposed through `OpType`; here it names this class itself.
+    typealias OpType = SplitOp
+    /// No-op: the only statement is commented out, so no output shape is inferred here.
+    func inferShape() {
+        // para.output.dim = para.input.dim
+    }
+    /// Runs `kernel.compute` on `buffer` with `para` and rethrows any error unchanged; `device` is not used.
+    func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
+        do {
+            try kernel.compute(commandBuffer: buffer, param: para)
+        } catch let error {
+            throw error
+        }
+    }
+    /// Debug hook: prints `type` followed by " output: "; no tensor contents are dumped.
+    func delogOutput() {
+        print(" \(type) output: ")
+    }
+
+}
+
+
+
+
+
+
diff --git a/metal/paddle-mobile/paddle-mobile/framework/Texture.swift b/metal/paddle-mobile/paddle-mobile/framework/Texture.swift index c6e8842182116001ceaa89510389cca29f1f7d7e..194d3d3015754cd2faf2dc3f4b4b098d762f2e53 100644 --- a/metal/paddle-mobile/paddle-mobile/framework/Texture.swift +++ b/metal/paddle-mobile/paddle-mobile/framework/Texture.swift @@ -38,6 +38,38 @@ extension InputTexture { } }
+
+/*
+ A 4-D tensor is stored in a texture with its transpose taken into account.
+ With post-transpose dimensions [a, b, c, d], the matching texture2d_array has
+ .width = c
+ .height = b
+ .len = (a * d + 3) / 4
+
+Tensors with fewer than 4 dimensions must use transpose [0, 1, 2, 3], i.e. transpose is not considered.
+
+// TODO: rules for extending transpose to low-rank tensors...
+// [a, b] -> [1, 1, a, b] transpose must be [0, 1, x, x]
+// [a] -> [1, 1, 1, a] transpose must be [0, 1, 2, 3]
+// [a, b, c] -> [1, a, b, c] transpose must be [0, x, x, x]
+
+A 3-D tensor [a, b, c] maps to a texture2d_array with
+.width = c
+.height = b
+.len = (a + 3) / 4
+
+ A 2-D tensor [a, b] maps to a texture2d_array with
+ .width = (b + 3) / 4
+ .height = a
+ .len = 1
+
+ A 1-D tensor [a] maps to a texture2d_array with
+ .width = (a + 3) / 4
+ .height = 1
+ .len = 1
+ */
+
+
 public class Texture: Tensorial { var dim: Dim public var tensorDim: Dim @@ -62,6 +94,11 @@ public class Texture: Tensorial { func initTexture(device: MTLDevice, inTranspose: [Int] = [0, 1, 2, 3], computePrecision: ComputePrecision = .Float16) { transpose = inTranspose + for i in 0..<(4 - tensorDim.cout()) { + if i != inTranspose[i] { + fatalError() + } + } let newDim = transpose.map { padToFourDim[$0] } let newLayout = transpose.map { layout.layoutWithDim[$0] } @@ -70,14 +107,25 @@ public class Texture: Tensorial { dim = Dim.init(inDim: newDim) let tmpTextureDes =
MTLTextureDescriptor.init() - - tmpTextureDes.width = newDim[2] - // layout.W ?? 1 - tmpTextureDes.height = newDim[1] - // layout.H ?? 1 - tmpTextureDes.depth = 1 - tmpTextureDes.arrayLength = ((newDim[0]) * (newDim[3]) + 3) / 4 tmpTextureDes.textureType = .type2DArray + tmpTextureDes.depth = 1 + + switch tensorDim.cout() { + case 4: + tmpTextureDes.width = newDim[2] + tmpTextureDes.height = newDim[1] + tmpTextureDes.arrayLength = ((newDim[0]) * (newDim[3]) + 3) / 4 + case 3: + tmpTextureDes.width = newDim[3] + tmpTextureDes.height = newDim[2] + tmpTextureDes.arrayLength = (newDim[1] + 3) / 4 + case 2, 1: + tmpTextureDes.width = (newDim[3] + 3) / 4 + tmpTextureDes.height = newDim[2] + tmpTextureDes.arrayLength = 1 + default: + fatalError("unreachable") + } if computePrecision == .Float16 { tmpTextureDes.pixelFormat = .rgba16Float