提交 ec13cb12 编写于 作者: D dolphin8 提交者: GitHub

Merge pull request #950 from dolphin8/metal

texture
...@@ -7,6 +7,16 @@ ...@@ -7,6 +7,16 @@
objects = { objects = {
/* Begin PBXBuildFile section */ /* Begin PBXBuildFile section */
4AA1EA862146625E00D0F791 /* BilinearInterpOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA852146625E00D0F791 /* BilinearInterpOp.swift */; };
4AA1EA88214662BD00D0F791 /* BilinearInterpKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA87214662BD00D0F791 /* BilinearInterpKernel.swift */; };
4AA1EA8A2146631C00D0F791 /* BilinearInterp.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA892146631C00D0F791 /* BilinearInterp.metal */; };
4AA1EA8C2146640900D0F791 /* SplitOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA8B2146640900D0F791 /* SplitOp.swift */; };
4AA1EA8E2146647F00D0F791 /* SplitKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA8D2146647F00D0F791 /* SplitKernel.swift */; };
4AA1EA90214664CD00D0F791 /* Split.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA8F214664CD00D0F791 /* Split.metal */; };
4AA1EA92214665D700D0F791 /* ShapeOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA91214665D700D0F791 /* ShapeOp.swift */; };
4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA932146661500D0F791 /* ShapeKernel.swift */; };
4AA1EA962146665A00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA952146665A00D0F791 /* FlattenKernel.swift */; };
4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA972146666500D0F791 /* FlattenOp.swift */; };
4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; }; 4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; };
4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; }; 4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; };
4AF928822135673D005B6C3A /* Concat.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* Concat.metal */; }; 4AF928822135673D005B6C3A /* Concat.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* Concat.metal */; };
...@@ -104,6 +114,16 @@ ...@@ -104,6 +114,16 @@
/* End PBXBuildFile section */ /* End PBXBuildFile section */
/* Begin PBXFileReference section */ /* Begin PBXFileReference section */
4AA1EA852146625E00D0F791 /* BilinearInterpOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BilinearInterpOp.swift; sourceTree = "<group>"; };
4AA1EA87214662BD00D0F791 /* BilinearInterpKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BilinearInterpKernel.swift; sourceTree = "<group>"; };
4AA1EA892146631C00D0F791 /* BilinearInterp.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BilinearInterp.metal; sourceTree = "<group>"; };
4AA1EA8B2146640900D0F791 /* SplitOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SplitOp.swift; sourceTree = "<group>"; };
4AA1EA8D2146647F00D0F791 /* SplitKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SplitKernel.swift; sourceTree = "<group>"; };
4AA1EA8F214664CD00D0F791 /* Split.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Split.metal; sourceTree = "<group>"; };
4AA1EA91214665D700D0F791 /* ShapeOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeOp.swift; sourceTree = "<group>"; };
4AA1EA932146661500D0F791 /* ShapeKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeKernel.swift; sourceTree = "<group>"; };
4AA1EA952146665A00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenKernel.swift; sourceTree = "<group>"; };
4AA1EA972146666500D0F791 /* FlattenOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenOp.swift; sourceTree = "<group>"; };
4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; }; 4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; };
4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; }; 4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; };
4AF928812135673D005B6C3A /* Concat.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Concat.metal; sourceTree = "<group>"; }; 4AF928812135673D005B6C3A /* Concat.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Concat.metal; sourceTree = "<group>"; };
...@@ -324,6 +344,10 @@ ...@@ -324,6 +344,10 @@
FCBCCC642122FCD700D94F7E /* TransposeOp.swift */, FCBCCC642122FCD700D94F7E /* TransposeOp.swift */,
FCBCCC66212306B000D94F7E /* ConcatOp.swift */, FCBCCC66212306B000D94F7E /* ConcatOp.swift */,
FCBCCC6A2123071700D94F7E /* BoxcoderOp.swift */, FCBCCC6A2123071700D94F7E /* BoxcoderOp.swift */,
4AA1EA8B2146640900D0F791 /* SplitOp.swift */,
4AA1EA91214665D700D0F791 /* ShapeOp.swift */,
4AA1EA972146666500D0F791 /* FlattenOp.swift */,
4AA1EA852146625E00D0F791 /* BilinearInterpOp.swift */,
FCBCCC6E2123097100D94F7E /* MulticlassNMSOp.swift */, FCBCCC6E2123097100D94F7E /* MulticlassNMSOp.swift */,
FCDE8A32212A917900F4A8F6 /* ConvTransposeOp.swift */, FCDE8A32212A917900F4A8F6 /* ConvTransposeOp.swift */,
FCEB684B212F093800D2448E /* PreluOp.swift */, FCEB684B212F093800D2448E /* PreluOp.swift */,
...@@ -369,6 +393,10 @@ ...@@ -369,6 +393,10 @@
FCBCCC622122FCC000D94F7E /* TransposeKernel.swift */, FCBCCC622122FCC000D94F7E /* TransposeKernel.swift */,
FCBCCC68212306D300D94F7E /* ConcatKernel.swift */, FCBCCC68212306D300D94F7E /* ConcatKernel.swift */,
FCBCCC6C2123073A00D94F7E /* BoxcoderKernel.swift */, FCBCCC6C2123073A00D94F7E /* BoxcoderKernel.swift */,
4AA1EA8D2146647F00D0F791 /* SplitKernel.swift */,
4AA1EA932146661500D0F791 /* ShapeKernel.swift */,
4AA1EA952146665A00D0F791 /* FlattenKernel.swift */,
4AA1EA87214662BD00D0F791 /* BilinearInterpKernel.swift */,
FCBCCC70212309A700D94F7E /* MulticlassNMSKernel.swift */, FCBCCC70212309A700D94F7E /* MulticlassNMSKernel.swift */,
FCDDC6C5212F9FB800E5EF74 /* PreluKernel.swift */, FCDDC6C5212F9FB800E5EF74 /* PreluKernel.swift */,
); );
...@@ -411,6 +439,8 @@ ...@@ -411,6 +439,8 @@
FC1B16B220EC9A4F00678B91 /* Kernels.metal */, FC1B16B220EC9A4F00678B91 /* Kernels.metal */,
FC4CB74820F0B954007C0C6D /* ConvKernel.metal */, FC4CB74820F0B954007C0C6D /* ConvKernel.metal */,
4AF928762133F1DB005B6C3A /* BoxCoder.metal */, 4AF928762133F1DB005B6C3A /* BoxCoder.metal */,
4AA1EA8F214664CD00D0F791 /* Split.metal */,
4AA1EA892146631C00D0F791 /* BilinearInterp.metal */,
4AF9287821341661005B6C3A /* Softmax.metal */, 4AF9287821341661005B6C3A /* Softmax.metal */,
FCEB6849212F00DB00D2448E /* PreluKernel.metal */, FCEB6849212F00DB00D2448E /* PreluKernel.metal */,
FCDDC6C9212FDF6800E5EF74 /* BatchNormKernel.metal */, FCDDC6C9212FDF6800E5EF74 /* BatchNormKernel.metal */,
...@@ -536,6 +566,7 @@ ...@@ -536,6 +566,7 @@
FC039B9F20E11CB20081E9F8 /* Tensor.swift in Sources */, FC039B9F20E11CB20081E9F8 /* Tensor.swift in Sources */,
FCA67CD7213827AC00BD58AA /* ConvAddBNReluKernel.metal in Sources */, FCA67CD7213827AC00BD58AA /* ConvAddBNReluKernel.metal in Sources */,
4AF9287921341661005B6C3A /* Softmax.metal in Sources */, 4AF9287921341661005B6C3A /* Softmax.metal in Sources */,
4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */,
FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */, FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */,
FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */, FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */,
FCDE8A33212A917900F4A8F6 /* ConvTransposeOp.swift in Sources */, FCDE8A33212A917900F4A8F6 /* ConvTransposeOp.swift in Sources */,
...@@ -553,6 +584,7 @@ ...@@ -553,6 +584,7 @@
FCDDC6C6212F9FB800E5EF74 /* PreluKernel.swift in Sources */, FCDDC6C6212F9FB800E5EF74 /* PreluKernel.swift in Sources */,
FCA67CD52138272900BD58AA /* ConvAddMetal.metal in Sources */, FCA67CD52138272900BD58AA /* ConvAddMetal.metal in Sources */,
FCBCCC5B2122F66F00D94F7E /* ConvBNReluKernel.swift in Sources */, FCBCCC5B2122F66F00D94F7E /* ConvBNReluKernel.swift in Sources */,
4AA1EA8C2146640900D0F791 /* SplitOp.swift in Sources */,
FC292C81214255BD00CF622F /* CPUCompute.mm in Sources */, FC292C81214255BD00CF622F /* CPUCompute.mm in Sources */,
FCEBC0F420F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift in Sources */, FCEBC0F420F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift in Sources */,
FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */, FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */,
...@@ -561,6 +593,7 @@ ...@@ -561,6 +593,7 @@
FCA67CD92138287B00BD58AA /* ConvBNReluKernel.metal in Sources */, FCA67CD92138287B00BD58AA /* ConvBNReluKernel.metal in Sources */,
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */, FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */,
FCEBC0F620F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift in Sources */, FCEBC0F620F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift in Sources */,
4AA1EA8A2146631C00D0F791 /* BilinearInterp.metal in Sources */,
FCDDC6CA212FDF6800E5EF74 /* BatchNormKernel.metal in Sources */, FCDDC6CA212FDF6800E5EF74 /* BatchNormKernel.metal in Sources */,
FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */, FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */,
FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */, FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */,
...@@ -573,13 +606,16 @@ ...@@ -573,13 +606,16 @@
FCBCCC592122F42700D94F7E /* ConvBNReluOp.swift in Sources */, FCBCCC592122F42700D94F7E /* ConvBNReluOp.swift in Sources */,
FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */, FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */,
FC9D038420E23B01000F735A /* Texture.swift in Sources */, FC9D038420E23B01000F735A /* Texture.swift in Sources */,
4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */,
FCBCCC652122FCD700D94F7E /* TransposeOp.swift in Sources */, FCBCCC652122FCD700D94F7E /* TransposeOp.swift in Sources */,
FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */, FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */,
FC039B9820E11C9A0081E9F8 /* Errors.swift in Sources */, FC039B9820E11C9A0081E9F8 /* Errors.swift in Sources */,
FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */, FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */,
4AA1EA8E2146647F00D0F791 /* SplitKernel.swift in Sources */,
FCD04E7420F3437E0007374F /* ConvAddKernel.swift in Sources */, FCD04E7420F3437E0007374F /* ConvAddKernel.swift in Sources */,
FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */, FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */,
FC292C5621421B4600CF622F /* PaddleMobileGPU.m in Sources */, FC292C5621421B4600CF622F /* PaddleMobileGPU.m in Sources */,
4AA1EA962146665A00D0F791 /* FlattenKernel.swift in Sources */,
FCD04E6620F314C50007374F /* PoolOp.swift in Sources */, FCD04E6620F314C50007374F /* PoolOp.swift in Sources */,
FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */, FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */,
FCBCCC6F2123097100D94F7E /* MulticlassNMSOp.swift in Sources */, FCBCCC6F2123097100D94F7E /* MulticlassNMSOp.swift in Sources */,
...@@ -590,11 +626,13 @@ ...@@ -590,11 +626,13 @@
FCBCCC71212309A700D94F7E /* MulticlassNMSKernel.swift in Sources */, FCBCCC71212309A700D94F7E /* MulticlassNMSKernel.swift in Sources */,
FCDC0FEB21099A1D00DC9EFB /* Tools.swift in Sources */, FCDC0FEB21099A1D00DC9EFB /* Tools.swift in Sources */,
FC0E2DBA20EE3B8D009C1FAC /* ReluKernel.swift in Sources */, FC0E2DBA20EE3B8D009C1FAC /* ReluKernel.swift in Sources */,
4AA1EA862146625E00D0F791 /* BilinearInterpOp.swift in Sources */,
FCBCCC6D2123073A00D94F7E /* BoxcoderKernel.swift in Sources */, FCBCCC6D2123073A00D94F7E /* BoxcoderKernel.swift in Sources */,
FCBCCC69212306D300D94F7E /* ConcatKernel.swift in Sources */, FCBCCC69212306D300D94F7E /* ConcatKernel.swift in Sources */,
FCDDC6C8212FA3CA00E5EF74 /* ConvTransposeKernel.swift in Sources */, FCDDC6C8212FA3CA00E5EF74 /* ConvTransposeKernel.swift in Sources */,
FC82735920E3C04200BE430A /* OpCreator.swift in Sources */, FC82735920E3C04200BE430A /* OpCreator.swift in Sources */,
FCA3A1652132A5EB00084FE5 /* Common.metal in Sources */, FCA3A1652132A5EB00084FE5 /* Common.metal in Sources */,
4AA1EA92214665D700D0F791 /* ShapeOp.swift in Sources */,
FCBCCC5D2122F8A100D94F7E /* DepthwiseConvOp.swift in Sources */, FCBCCC5D2122F8A100D94F7E /* DepthwiseConvOp.swift in Sources */,
FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */, FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */,
FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */, FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */,
...@@ -613,11 +651,13 @@ ...@@ -613,11 +651,13 @@
FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */, FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */,
FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */, FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */,
FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */, FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */,
4AA1EA90214664CD00D0F791 /* Split.metal in Sources */,
FCD04E6820F315020007374F /* PoolKernel.swift in Sources */, FCD04E6820F315020007374F /* PoolKernel.swift in Sources */,
FC0226582138F38D00F395E2 /* PoolKernel.metal in Sources */, FC0226582138F38D00F395E2 /* PoolKernel.metal in Sources */,
FC039BAD20E11CBC0081E9F8 /* ReluOp.swift in Sources */, FC039BAD20E11CBC0081E9F8 /* ReluOp.swift in Sources */,
FCBCCC572122F41300D94F7E /* DwConvBNReluOp.swift in Sources */, FCBCCC572122F41300D94F7E /* DwConvBNReluOp.swift in Sources */,
FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */, FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */,
4AA1EA88214662BD00D0F791 /* BilinearInterpKernel.swift in Sources */,
FC039B9720E11C9A0081E9F8 /* Extensions.swift in Sources */, FC039B9720E11C9A0081E9F8 /* Extensions.swift in Sources */,
); );
runOnlyForDeploymentPostprocessing = 0; runOnlyForDeploymentPostprocessing = 0;
......
...@@ -60,7 +60,11 @@ class OpCreator<P: PrecisionType> { ...@@ -60,7 +60,11 @@ class OpCreator<P: PrecisionType> {
gTransposeType : TransposeOp<P>.creat, gTransposeType : TransposeOp<P>.creat,
gPriorBoxType : PriorBoxOp<P>.creat, gPriorBoxType : PriorBoxOp<P>.creat,
gPreluType : PreluOp<P>.creat, gPreluType : PreluOp<P>.creat,
gConv2dTransposeType : ConvTransposeOp<P>.creat] gConv2dTransposeType : ConvTransposeOp<P>.creat,
gBilinearInterpType : BilinearInterpOp<P>.creat,
gSplit : SplitOp<P>.creat,
gShape : ShapeOp<P>.creat,
gFlatten : FlattenOp<P>.creat]
private init(){} private init(){}
} }
...@@ -139,7 +139,10 @@ let gConvBnReluType = "conv_bn_relu" ...@@ -139,7 +139,10 @@ let gConvBnReluType = "conv_bn_relu"
let gDwConvBnReluType = "depth_conv_bn_relu" let gDwConvBnReluType = "depth_conv_bn_relu"
let gPreluType = "prelu" let gPreluType = "prelu"
let gConv2dTransposeType = "conv2d_transpose" let gConv2dTransposeType = "conv2d_transpose"
let gBilinearInterpType = "bilinear_interp"
let gSplit = "split"
let gShape = "shape"
let gFlatten = "flatten"
let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]), let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
gBatchNormType : (inputs: ["X"], outputs: ["Y"]), gBatchNormType : (inputs: ["X"], outputs: ["Y"]),
...@@ -161,5 +164,9 @@ let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Out ...@@ -161,5 +164,9 @@ let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Out
gMulticlassNMSType : (inputs: ["BBoxes", "Scores"], outputs: ["Out"]), gMulticlassNMSType : (inputs: ["BBoxes", "Scores"], outputs: ["Out"]),
gPriorBoxType : (inputs: ["Input", "Image"], outputs: ["Boxes", "Variances"]), gPriorBoxType : (inputs: ["Input", "Image"], outputs: ["Boxes", "Variances"]),
gPreluType : (inputs: ["X"], outputs: ["Out"]), gPreluType : (inputs: ["X"], outputs: ["Out"]),
gConv2dTransposeType : (inputs: ["Input"], outputs: ["Output"]) gConv2dTransposeType : (inputs: ["Input"], outputs: ["Output"]),
gBilinearInterpType : (inputs: ["X"], outputs: ["Out"]),
gSplit : (inputs: ["Input"], outputs: ["Out"]),
gShape : (inputs: ["Input"], outputs: ["Out"]),
gFlatten : (inputs: ["Input"], outputs: ["Out"])
] ]
...@@ -26,7 +26,6 @@ class BatchNormParam<P: PrecisionType>: OpParam { ...@@ -26,7 +26,6 @@ class BatchNormParam<P: PrecisionType>: OpParam {
inputVariance = try BatchNormParam.inputVariance(inputs: opDesc.paraInputs, from: inScope) inputVariance = try BatchNormParam.inputVariance(inputs: opDesc.paraInputs, from: inScope)
epsilon = try BatchNormParam.getAttr(key: "epsilon", attrs: opDesc.attrs) epsilon = try BatchNormParam.getAttr(key: "epsilon", attrs: opDesc.attrs)
momentum = try BatchNormParam.getAttr(key: "momentum", attrs: opDesc.attrs) momentum = try BatchNormParam.getAttr(key: "momentum", attrs: opDesc.attrs)
is_test = try BatchNormParam.getAttr(key: "is_test", attrs: opDesc.attrs)
} catch let error { } catch let error {
throw error throw error
} }
...@@ -39,7 +38,6 @@ class BatchNormParam<P: PrecisionType>: OpParam { ...@@ -39,7 +38,6 @@ class BatchNormParam<P: PrecisionType>: OpParam {
let inputVariance: Tensor<ParamPrecisionType> let inputVariance: Tensor<ParamPrecisionType>
let epsilon: Float let epsilon: Float
let momentum: Float let momentum: Float
let is_test: Bool
} }
class BatchNormOp<P: PrecisionType>: Operator<BatchNormKernel<P>, BatchNormParam<P>>, Runable, Creator, InferShaperable{ class BatchNormOp<P: PrecisionType>: Operator<BatchNormKernel<P>, BatchNormParam<P>>, Runable, Creator, InferShaperable{
......
///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License. */
import Foundation
/// Parameter bundle for the bilinear-interpolation operator.
///
/// Everything is resolved once, at construction time, from the op description
/// and the scope; any lookup failure is propagated to the caller unchanged.
class BilinearInterpParam<P: PrecisionType>: OpParam {
    typealias ParamPrecisionType = P

    /// Texture looked up through `inputX`.
    let input: Texture<P>
    /// Texture looked up through `outputOut`.
    var output: Texture<P>
    /// Value of the `"out_h"` attribute.
    let out_h: Int
    /// Value of the `"out_w"` attribute.
    let out_w: Int

    /// Resolves the input/output textures and the `out_h` / `out_w` attributes.
    ///
    /// - Parameters:
    ///   - opDesc: Description supplying the input/output names and attributes.
    ///   - inScope: Scope the named variables are fetched from.
    /// - Throws: Whatever the underlying lookups throw.
    required init(opDesc: OpDesc, inScope: Scope) throws {
        input = try BilinearInterpParam.inputX(inputs: opDesc.inputs, from: inScope)
        // Disabled sanity check kept from the original implementation:
        // if (input.transpose != [0, 2, 3, 1]) || (input.tensorDim.cout() != 4) {
        //     fatalError()
        // }
        output = try BilinearInterpParam.outputOut(outputs: opDesc.outputs, from: inScope)
        out_h = try BilinearInterpParam.getAttr(key: "out_h", attrs: opDesc.attrs)
        out_w = try BilinearInterpParam.getAttr(key: "out_w", attrs: opDesc.attrs)
    }
}
/// Operator wrapper that runs `BilinearInterpKernel` with a `BilinearInterpParam`.
class BilinearInterpOp<P: PrecisionType>: Operator<BilinearInterpKernel<P>, BilinearInterpParam<P>>, Runable, Creator, InferShaperable {
    typealias OpType = BilinearInterpOp<P>

    /// Shape inference is currently a no-op for this operator.
    func inferShape() {
        // para.output.dim = para.input.dim
    }

    /// Encodes the kernel into `buffer`.
    ///
    /// - Parameters:
    ///   - device: Unused here; part of the required signature.
    ///   - buffer: Command buffer the kernel encodes its work into.
    /// - Throws: Any error raised by `kernel.compute`.
    func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
        try kernel.compute(commandBuffer: buffer, param: para)
    }

    /// Prints a header line for this operator's output; no values are dumped.
    func delogOutput() {
        print(" \(type) output: ")
    }
}
///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License. */
import Foundation
/// Parameter bundle for the flatten operator; only the output texture is resolved.
class FlattenParam<P: PrecisionType>: OpParam {
    typealias ParamPrecisionType = P

    /// Texture looked up through `output(outputs:from:)`.
    var output: Texture<P>

    /// Resolves the output texture from the op description.
    ///
    /// - Parameters:
    ///   - opDesc: Description supplying the output names.
    ///   - inScope: Scope the named variable is fetched from.
    /// - Throws: Whatever the underlying lookup throws.
    required init(opDesc: OpDesc, inScope: Scope) throws {
        // NOTE(review): BilinearInterpParam resolves its output with `outputOut`;
        // confirm that `output(outputs:from:)` looks up the key flatten is
        // registered with.
        output = try FlattenParam.output(outputs: opDesc.outputs, from: inScope)
    }
}
/// Operator wrapper that runs `FlattenKernel` with a `FlattenParam`.
class FlattenOp<P: PrecisionType>: Operator<FlattenKernel<P>, FlattenParam<P>>, Runable, Creator, InferShaperable {
    // Must name this class itself. It previously pointed at `SplitOp<P>`
    // (copy/paste slip), which would make anything keyed on `OpType`
    // (presumably the `Creator` factory) produce the wrong operator.
    typealias OpType = FlattenOp<P>

    /// Shape inference is currently a no-op for this operator.
    func inferShape() {
        // para.output.dim = para.input.dim
    }

    /// Encodes the kernel into `buffer`.
    ///
    /// - Parameters:
    ///   - device: Unused here; part of the required signature.
    ///   - buffer: Command buffer the kernel encodes its work into.
    /// - Throws: Any error raised by `kernel.compute`.
    func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
        try kernel.compute(commandBuffer: buffer, param: para)
    }

    /// Prints a header line for this operator's output; no values are dumped.
    func delogOutput() {
        print(" \(type) output: ")
    }
}
...@@ -15,20 +15,20 @@ ...@@ -15,20 +15,20 @@
import Foundation import Foundation
class BatchNormKernel<P: PrecisionType>: Kernel, Computable { class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
var newScale: MTLBuffer // var newScale: MTLBuffer
var newBias: MTLBuffer // var newBias: MTLBuffer
//
required init(device: MTLDevice, param: BatchNormParam<P>) { required init(device: MTLDevice, param: BatchNormParam<P>) {
guard let newScale = device.makeBuffer(length: param.inputScale.buffer.length) else { // guard let newScale = device.makeBuffer(length: param.inputScale.buffer.length) else {
fatalError() // fatalError()
} // }
//
guard let newBias = device.makeBuffer(length: param.inputBias.buffer.length) else { // guard let newBias = device.makeBuffer(length: param.inputBias.buffer.length) else {
fatalError() // fatalError()
} // }
self.newScale = newScale // self.newScale = newScale
self.newBias = newBias // self.newBias = newBias
//
if computePrecision == .Float32 { if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "batchnorm") super.init(device: device, inFunctionName: "batchnorm")
} else if computePrecision == .Float16 { } else if computePrecision == .Float16 {
...@@ -36,37 +36,37 @@ class BatchNormKernel<P: PrecisionType>: Kernel, Computable { ...@@ -36,37 +36,37 @@ class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
} else { } else {
fatalError() fatalError()
} }
//
let varianceBuffer : MTLBuffer = param.inputVariance.buffer // let varianceBuffer : MTLBuffer = param.inputVariance.buffer
//
var invStd: [Float32] = Array(repeating: 0, count: varianceBuffer.length) // var invStd: [Float32] = Array(repeating: 0, count: varianceBuffer.length)
let varianceContents = varianceBuffer.contents().assumingMemoryBound(to: P.self) // let varianceContents = varianceBuffer.contents().assumingMemoryBound(to: P.self)
for i in 0..<(varianceBuffer.length / MemoryLayout<P>.stride) { // for i in 0..<(varianceBuffer.length / MemoryLayout<P>.stride) {
invStd[i] = 1 / (Float32(varianceContents[i]) + param.epsilon).squareRoot() // invStd[i] = 1 / (Float32(varianceContents[i]) + param.epsilon).squareRoot()
} // }
//
let newScaleContents = newScale.contents().assumingMemoryBound(to: P.self) // let newScaleContents = newScale.contents().assumingMemoryBound(to: P.self)
let newBiasContents = newBias.contents().assumingMemoryBound(to: P.self) // let newBiasContents = newBias.contents().assumingMemoryBound(to: P.self)
let scale : MTLBuffer = param.inputScale.buffer // let scale : MTLBuffer = param.inputScale.buffer
let scaleContents = scale.contents().assumingMemoryBound(to: P.self) // let scaleContents = scale.contents().assumingMemoryBound(to: P.self)
let bias : MTLBuffer = param.inputBias.buffer // let bias : MTLBuffer = param.inputBias.buffer
let biasContents = bias.contents().assumingMemoryBound(to: P.self) // let biasContents = bias.contents().assumingMemoryBound(to: P.self)
let meanContents = param.inputMean.buffer.contents().assumingMemoryBound(to: P.self) // let meanContents = param.inputMean.buffer.contents().assumingMemoryBound(to: P.self)
//
for i in 0..<(newScale.length / MemoryLayout<P>.stride) { // for i in 0..<(newScale.length / MemoryLayout<P>.stride) {
newScaleContents[i] = P(invStd[i] * Float32(scaleContents[i])) // newScaleContents[i] = P(invStd[i] * Float32(scaleContents[i]))
newBiasContents[i] = P(Float32(biasContents[i]) - Float32(meanContents[i]) * invStd[i] * Float32(scaleContents[i])) // newBiasContents[i] = P(Float32(biasContents[i]) - Float32(meanContents[i]) * invStd[i] * Float32(scaleContents[i]))
} // }
} }
func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encoder is nil") throw PaddleMobileError.predictError(message: " encoder is nil")
} }
encoder.setTexture(param.input.metalTexture, index: 0) // encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1) // encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setBuffer(newScale, offset: 0, index: 0) // encoder.setBuffer(newScale, offset: 0, index: 0)
encoder.setBuffer(newBias, offset: 0, index: 1) // encoder.setBuffer(newBias, offset: 0, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding() encoder.endEncoding()
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
/// Scale factors that `BilinearInterpKernel.compute` copies to the compute
/// encoder with `setBytes` at buffer index 0.
struct BilinearInterpMetalParam {
/// Input extent divided by output extent along `tensorDim.dims[2]`.
var ratio_h: Float32
/// Input extent divided by output extent along `tensorDim.dims[3]`.
var ratio_w: Float32
}
/// Kernel that encodes the `bilinear_interp` / `bilinear_interp_half` compute function.
class BilinearInterpKernel<P: PrecisionType>: Kernel, Computable {
    /// Encodes one bilinear-interpolation dispatch into `commandBuffer`.
    ///
    /// Binds the input texture at index 0 and the output texture at index 1,
    /// uploads the input/output size ratios as a `BilinearInterpMetalParam` at
    /// buffer index 0, then dispatches over the output texture.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer to encode into.
    ///   - param: Operator parameters holding the input and output textures.
    /// - Throws: `PaddleMobileError.predictError` when no compute encoder can be created.
    func compute(commandBuffer: MTLCommandBuffer, param: BilinearInterpParam<P>) throws {
        guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
            throw PaddleMobileError.predictError(message: " encode is nil")
        }
        encoder.setTexture(param.input.metalTexture, index: 0)
        encoder.setTexture(param.output.metalTexture, index: 1)

        let ratio_h: Float32 = Float32(param.input.tensorDim.dims[2]) / Float32(param.output.tensorDim.dims[2])
        let ratio_w: Float32 = Float32(param.input.tensorDim.dims[3]) / Float32(param.output.tensorDim.dims[3])
        var p = BilinearInterpMetalParam.init(ratio_h: ratio_h, ratio_w: ratio_w)

        // The length must describe the value actually being uploaded. It used to
        // be sized from `ConcatMetalParam`, an unrelated struct, so Metal was
        // told to copy a byte count that did not match `p`.
        encoder.setBytes(&p, length: MemoryLayout<BilinearInterpMetalParam>.size, index: 0)

        encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
        encoder.endEncoding()
    }

    /// Allocates the output texture (reusing the input's transpose) and selects
    /// the compute function that matches the global `computePrecision`.
    ///
    /// - Parameters:
    ///   - device: Metal device used for texture and pipeline creation.
    ///   - param: Operator parameters whose output texture is initialised here.
    required init(device: MTLDevice, param: BilinearInterpParam<P>) {
        param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision)
        if computePrecision == .Float32 {
            super.init(device: device, inFunctionName: "bilinear_interp")
        } else if computePrecision == .Float16 {
            super.init(device: device, inFunctionName: "bilinear_interp_half")
        } else {
            // Any other precision is treated as unsupported.
            fatalError()
        }
    }
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
/// Empty placeholder: declares no fields, and `FlattenKernel` as written in this
/// file never uploads or references it.
struct FlattenMetalParam {
}
/// Kernel for the flatten operator.
///
/// NOTE(review): this looks like a stub. It loads the `split` / `split_half`
/// functions and `compute` never dispatches any work; confirm that is intended.
class FlattenKernel<P: PrecisionType>: Kernel, Computable {
    /// Allocates the output texture and selects the compute function that
    /// matches the global `computePrecision`.
    ///
    /// - Parameters:
    ///   - device: Metal device used for texture and pipeline creation.
    ///   - param: Operator parameters whose output texture is initialised here.
    required init(device: MTLDevice, param: FlattenParam<P>) {
        param.output.initTexture(device: device, computePrecision: computePrecision)
        switch computePrecision {
        case .Float32:
            super.init(device: device, inFunctionName: "split")
        case .Float16:
            super.init(device: device, inFunctionName: "split_half")
        default:
            // Any other precision is treated as unsupported.
            fatalError()
        }
    }

    /// Opens a compute encoder, binds the output texture at index 0 and closes
    /// the encoder again without dispatching.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer to encode into.
    ///   - param: Operator parameters holding the output texture.
    /// - Throws: `PaddleMobileError.predictError` when no compute encoder can be created.
    func compute(commandBuffer: MTLCommandBuffer, param: FlattenParam<P>) throws {
        guard let computeEncoder = commandBuffer.makeComputeCommandEncoder() else {
            throw PaddleMobileError.predictError(message: " encode is nil")
        }
        let outTexture = param.output.metalTexture
        computeEncoder.setTexture(outTexture, index: 0)
        computeEncoder.endEncoding()
    }
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
/// Scale factors that `ResizeBilinearKernel.compute` copies to the compute
/// encoder with `setBytes` at buffer index 0.
struct ResizeBilinearMetalParam {
/// Input extent divided by output extent along `tensorDim.dims[2]`.
var ratio_h: Float32
/// Input extent divided by output extent along `tensorDim.dims[3]`.
var ratio_w: Float32
}
/// Kernel that encodes the `resize_bilinear` / `resize_bilinear_half` compute function.
class ResizeBilinearKernel<P: PrecisionType>: Kernel, Computable {
    /// Encodes one bilinear-resize dispatch into `commandBuffer`.
    ///
    /// Binds the input texture at index 0 and the output texture at index 1,
    /// uploads the input/output size ratios as a `ResizeBilinearMetalParam` at
    /// buffer index 0, then dispatches over the output texture.
    ///
    /// - Parameters:
    ///   - commandBuffer: Command buffer to encode into.
    ///   - param: Operator parameters holding the input and output textures.
    /// - Throws: `PaddleMobileError.predictError` when no compute encoder can be created.
    func compute(commandBuffer: MTLCommandBuffer, param: ResizeBilinearParam<P>) throws {
        guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
            throw PaddleMobileError.predictError(message: " encode is nil")
        }
        encoder.setTexture(param.input.metalTexture, index: 0)
        encoder.setTexture(param.output.metalTexture, index: 1)

        let ratio_h: Float32 = Float32(param.input.tensorDim.dims[2]) / Float32(param.output.tensorDim.dims[2])
        let ratio_w: Float32 = Float32(param.input.tensorDim.dims[3]) / Float32(param.output.tensorDim.dims[3])
        var p = ResizeBilinearMetalParam.init(ratio_h: ratio_h, ratio_w: ratio_w)

        // The length must describe the value actually being uploaded. It used to
        // be sized from `ConcatMetalParam`, an unrelated struct, so Metal was
        // told to copy a byte count that did not match `p`.
        encoder.setBytes(&p, length: MemoryLayout<ResizeBilinearMetalParam>.size, index: 0)

        encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
        encoder.endEncoding()
    }

    /// Allocates the output texture (reusing the input's transpose) and selects
    /// the compute function that matches the global `computePrecision`.
    ///
    /// - Parameters:
    ///   - device: Metal device used for texture and pipeline creation.
    ///   - param: Operator parameters whose output texture is initialised here.
    required init(device: MTLDevice, param: ResizeBilinearParam<P>) {
        param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision)
        if computePrecision == .Float32 {
            super.init(device: device, inFunctionName: "resize_bilinear")
        } else if computePrecision == .Float16 {
            super.init(device: device, inFunctionName: "resize_bilinear_half")
        } else {
            // Any other precision is treated as unsupported.
            fatalError()
        }
    }
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
/// Placeholder for values passed to the shape shader; it currently declares no fields.
struct ShapeMetalParam {}
/// Kernel wrapper for the shape operator.
///
/// NOTE(review): this looks like a stub. `compute` only binds the output
/// texture and closes the encoder (no pipeline state is set and nothing is
/// dispatched), and `init` loads the `split` / `split_half` shader functions.
class ShapeKernel<P: PrecisionType>: Kernel, Computable {
    /// Opens a compute encoder, binds the output texture at index 0 and
    /// immediately ends encoding.
    ///
    /// - Throws: `PaddleMobileError.predictError` when a compute command
    ///   encoder cannot be created.
    func compute(commandBuffer: MTLCommandBuffer, param: ShapeParam<P>) throws {
        guard let computeEncoder = commandBuffer.makeComputeCommandEncoder() else {
            throw PaddleMobileError.predictError(message: " encode is nil")
        }
        computeEncoder.setTexture(param.output.metalTexture, index: 0)
        computeEncoder.endEncoding()
    }

    /// Allocates the output texture, then loads the shader function that
    /// matches the global `computePrecision`.
    required init(device: MTLDevice, param: ShapeParam<P>) {
        param.output.initTexture(device: device, computePrecision: computePrecision)

        let functionName: String
        if computePrecision == .Float32 {
            functionName = "split"
        } else if computePrecision == .Float16 {
            functionName = "split_half"
        } else {
            // Unsupported precision: abort before the superclass is initialised.
            fatalError()
        }
        super.init(device: device, inFunctionName: functionName)
    }
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
/// Placeholder for values passed to the split shader; it currently declares no fields.
struct SplitMetalParam {}
/// Kernel wrapper for the split operator.
///
/// NOTE(review): `compute` only binds the output texture and closes the
/// encoder; no pipeline state is set and nothing is dispatched, so this
/// appears to be a placeholder implementation.
class SplitKernel<P: PrecisionType>: Kernel, Computable {
    /// Opens a compute encoder, binds the output texture at index 0 and
    /// immediately ends encoding.
    ///
    /// - Throws: `PaddleMobileError.predictError` when a compute command
    ///   encoder cannot be created.
    func compute(commandBuffer: MTLCommandBuffer, param: SplitParam<P>) throws {
        guard let computeEncoder = commandBuffer.makeComputeCommandEncoder() else {
            throw PaddleMobileError.predictError(message: " encode is nil")
        }
        computeEncoder.setTexture(param.output.metalTexture, index: 0)
        computeEncoder.endEncoding()
    }

    /// Allocates the output texture, then loads the `split` shader variant
    /// that matches the global `computePrecision`.
    required init(device: MTLDevice, param: SplitParam<P>) {
        param.output.initTexture(device: device, computePrecision: computePrecision)

        let functionName: String
        if computePrecision == .Float32 {
            functionName = "split"
        } else if computePrecision == .Float16 {
            functionName = "split_half"
        } else {
            // Unsupported precision: abort before the superclass is initialised.
            fatalError()
        }
        super.init(device: device, inFunctionName: functionName)
    }
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
// Constant-buffer layout read by the bilinear_interp kernels at buffer(0).
// Each ratio is multiplied by an output coordinate to obtain the matching
// source coordinate in the input texture.
struct bilinear_interp_param {
    // Explicit output sizes are currently disabled:
    //   int32_t out_h;
    //   int32_t out_w;
    float ratio_h;  // scale applied to the output row index (gid.y)
    float ratio_w;  // scale applied to the output column index (gid.x)
};
// Bilinear resampling of every slice of `input` into `output` (float variant).
//
//   input  [[texture(0)]] : source texture array
//   output [[texture(2)]] : destination texture array
//   pm     [[buffer(0)]]  : per-axis scale factors (source coord = output coord * ratio)
//
// One thread produces one output texel. When both textures have the same
// width and height the texel is copied unchanged.
kernel void bilinear_interp(texture2d_array<float, access::read> input [[texture(0)]],
                            texture2d_array<float, access::write> output [[texture(2)]],
                            constant bilinear_interp_param & pm [[buffer(0)]],
                            uint3 gid [[thread_position_in_grid]]) {
    // Threads outside the output (e.g. when the grid is rounded up to whole
    // threadgroups) must not read or write anything.
    if (gid.x >= output.get_width() ||
        gid.y >= output.get_height() ||
        gid.z >= output.get_array_size()) {
        return;
    }

    float4 r;
    if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) {
        r = input.read(gid.xy, gid.z);
    } else {
        // Fractional source position for this output texel.
        float w = gid.x * pm.ratio_w;
        float h = gid.y * pm.ratio_h;
        // Top-left neighbour and its right / lower neighbours.
        uint w0 = w, h0 = h;
        uint w1 = w0 + 1, h1 = h0 + 1;
        // Interpolation weights: *1lambda towards (w1, h1), *2lambda towards (w0, h0).
        float w1lambda = w - w0, h1lambda = h - h0;
        float w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda;
        // Clamp the far neighbours at the right / bottom edge.
        if (w1 >= input.get_width()) w1 = w0;
        if (h1 >= input.get_height()) h1 = h0;
        float4 r0 = input.read(uint2(w0, h0), gid.z);
        float4 r1 = input.read(uint2(w1, h0), gid.z);
        float4 r2 = input.read(uint2(w0, h1), gid.z);
        float4 r3 = input.read(uint2(w1, h1), gid.z);
        r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3);
    }
    output.write(r, gid.xy, gid.z);
}
// Bilinear resampling of every slice of `input` into `output` (half variant).
//
//   input  [[texture(0)]] : source texture array
//   output [[texture(2)]] : destination texture array
//   pm     [[buffer(0)]]  : per-axis scale factors (source coord = output coord * ratio)
//
// One thread produces one output texel. When both textures have the same
// width and height the texel is copied unchanged.
kernel void bilinear_interp_half(texture2d_array<half, access::read> input [[texture(0)]],
                                 texture2d_array<half, access::write> output [[texture(2)]],
                                 constant bilinear_interp_param & pm [[buffer(0)]],
                                 uint3 gid [[thread_position_in_grid]]) {
    // Threads outside the output (e.g. when the grid is rounded up to whole
    // threadgroups) must not read or write anything.
    if (gid.x >= output.get_width() ||
        gid.y >= output.get_height() ||
        gid.z >= output.get_array_size()) {
        return;
    }

    half4 r;
    if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) {
        r = input.read(gid.xy, gid.z);
    } else {
        // Source coordinates are computed in float: a half only has an 11-bit
        // significand, so at large coordinates it drops the fractional part
        // and can even round up to a texel index past the input edge.
        float w = gid.x * pm.ratio_w;
        float h = gid.y * pm.ratio_h;
        // Top-left neighbour and its right / lower neighbours.
        uint w0 = w, h0 = h;
        uint w1 = w0 + 1, h1 = h0 + 1;
        // The weights lie in [0, 1), where half precision is sufficient.
        half w1lambda = half(w - w0), h1lambda = half(h - h0);
        half w2lambda = 1.0h - w1lambda, h2lambda = 1.0h - h1lambda;
        // Clamp the far neighbours at the right / bottom edge.
        if (w1 >= input.get_width()) w1 = w0;
        if (h1 >= input.get_height()) h1 = h0;
        half4 r0 = input.read(uint2(w0, h0), gid.z);
        half4 r1 = input.read(uint2(w1, h0), gid.z);
        half4 r2 = input.read(uint2(w0, h1), gid.z);
        half4 r3 = input.read(uint2(w1, h1), gid.z);
        r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3);
    }
    // A single write is enough; the original issued the same write twice.
    output.write(r, gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
// Constant-buffer layout read by the resize_bilinear kernels at buffer(0).
// Each ratio is multiplied by an output coordinate to obtain the matching
// source coordinate in the input texture.
struct resize_bilinear_param {
    // Explicit output sizes are currently disabled:
    //   int32_t out_h;
    //   int32_t out_w;
    float ratio_h;  // scale applied to the output row index (gid.y)
    float ratio_w;  // scale applied to the output column index (gid.x)
};
// Bilinear resize of every slice of `input` into `output` (float variant).
//
//   input  [[texture(0)]] : source texture array
//   output [[texture(2)]] : destination texture array
//   pm     [[buffer(0)]]  : per-axis scale factors (source coord = output coord * ratio)
//
// One thread produces one output texel. When both textures have the same
// width and height the texel is copied unchanged.
kernel void resize_bilinear(texture2d_array<float, access::read> input [[texture(0)]],
                            texture2d_array<float, access::write> output [[texture(2)]],
                            constant resize_bilinear_param & pm [[buffer(0)]],
                            uint3 gid [[thread_position_in_grid]]) {
    // Threads outside the output (e.g. when the grid is rounded up to whole
    // threadgroups) must not read or write anything.
    if (gid.x >= output.get_width() ||
        gid.y >= output.get_height() ||
        gid.z >= output.get_array_size()) {
        return;
    }

    float4 r;
    if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) {
        r = input.read(gid.xy, gid.z);
    } else {
        // Fractional source position for this output texel.
        float w = gid.x * pm.ratio_w;
        float h = gid.y * pm.ratio_h;
        // Top-left neighbour and its right / lower neighbours.
        uint w0 = w, h0 = h;
        uint w1 = w0 + 1, h1 = h0 + 1;
        // Interpolation weights: *1lambda towards (w1, h1), *2lambda towards (w0, h0).
        float w1lambda = w - w0, h1lambda = h - h0;
        float w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda;
        // Clamp the far neighbours at the right / bottom edge.
        if (w1 >= input.get_width()) w1 = w0;
        if (h1 >= input.get_height()) h1 = h0;
        float4 r0 = input.read(uint2(w0, h0), gid.z);
        float4 r1 = input.read(uint2(w1, h0), gid.z);
        float4 r2 = input.read(uint2(w0, h1), gid.z);
        float4 r3 = input.read(uint2(w1, h1), gid.z);
        r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3);
    }
    output.write(r, gid.xy, gid.z);
}
// Bilinear resize of every slice of `input` into `output` (half variant).
//
//   input  [[texture(0)]] : source texture array
//   output [[texture(2)]] : destination texture array
//   pm     [[buffer(0)]]  : per-axis scale factors (source coord = output coord * ratio)
//
// One thread produces one output texel. When both textures have the same
// width and height the texel is copied unchanged.
kernel void resize_bilinear_half(texture2d_array<half, access::read> input [[texture(0)]],
                                 texture2d_array<half, access::write> output [[texture(2)]],
                                 constant resize_bilinear_param & pm [[buffer(0)]],
                                 uint3 gid [[thread_position_in_grid]]) {
    // Threads outside the output (e.g. when the grid is rounded up to whole
    // threadgroups) must not read or write anything.
    if (gid.x >= output.get_width() ||
        gid.y >= output.get_height() ||
        gid.z >= output.get_array_size()) {
        return;
    }

    half4 r;
    if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) {
        r = input.read(gid.xy, gid.z);
    } else {
        // Source coordinates are computed in float: a half only has an 11-bit
        // significand, so at large coordinates it drops the fractional part
        // and can even round up to a texel index past the input edge.
        float w = gid.x * pm.ratio_w;
        float h = gid.y * pm.ratio_h;
        // Top-left neighbour and its right / lower neighbours.
        uint w0 = w, h0 = h;
        uint w1 = w0 + 1, h1 = h0 + 1;
        // The weights lie in [0, 1), where half precision is sufficient.
        half w1lambda = half(w - w0), h1lambda = half(h - h0);
        half w2lambda = 1.0h - w1lambda, h2lambda = 1.0h - h1lambda;
        // Clamp the far neighbours at the right / bottom edge.
        if (w1 >= input.get_width()) w1 = w0;
        if (h1 >= input.get_height()) h1 = h0;
        half4 r0 = input.read(uint2(w0, h0), gid.z);
        half4 r1 = input.read(uint2(w1, h0), gid.z);
        half4 r2 = input.read(uint2(w0, h1), gid.z);
        half4 r3 = input.read(uint2(w1, h1), gid.z);
        r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3);
    }
    // A single write is enough; the original issued the same write twice.
    output.write(r, gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
// Placeholder split kernel (float variant): the split itself is not
// implemented yet, so every in-range output texel is filled with zeros.
//
//   output [[texture(0)]] : destination texture array
kernel void split(texture2d_array<float, access::write> output [[texture(0)]],
                  uint3 gid [[thread_position_in_grid]]) {
    // Threads outside the output (e.g. when the grid is rounded up to whole
    // threadgroups) must not write anything.
    if (gid.x >= output.get_width() ||
        gid.y >= output.get_height() ||
        gid.z >= output.get_array_size()) {
        return;
    }
    // Explicitly zero-initialised: the original wrote an uninitialised
    // vector, i.e. indeterminate data, into the texture.
    float4 r = float4(0.0);
    output.write(r, gid.xy, gid.z);
}
// Placeholder split kernel (half variant): the split itself is not
// implemented yet, so every in-range output texel is filled with zeros.
//
//   output [[texture(0)]] : destination texture array
kernel void split_half(texture2d_array<half, access::write> output [[texture(0)]],
                       uint3 gid [[thread_position_in_grid]]) {
    // Threads outside the output (e.g. when the grid is rounded up to whole
    // threadgroups) must not write anything.
    if (gid.x >= output.get_width() ||
        gid.y >= output.get_height() ||
        gid.z >= output.get_array_size()) {
        return;
    }
    // Explicitly zero-initialised: the original converted and wrote an
    // uninitialised vector, i.e. indeterminate data, into the texture.
    half4 r = half4(0.0h);
    output.write(r, gid.xy, gid.z);
}
...@@ -44,14 +44,14 @@ class ReshapeParam<P: PrecisionType>: OpParam { ...@@ -44,14 +44,14 @@ class ReshapeParam<P: PrecisionType>: OpParam {
output.padToFourDim = Dim.init(inDim: dim) output.padToFourDim = Dim.init(inDim: dim)
output.dim = output.padToFourDim output.dim = output.padToFourDim
inplace = try ReshapeParam.getAttr(key: "inplace", attrs: opDesc.attrs) // inplace = try ReshapeParam.getAttr(key: "inplace", attrs: opDesc.attrs)
} catch let error { } catch let error {
throw error throw error
} }
} }
let input: Texture<P> let input: Texture<P>
let shape: [Int32] let shape: [Int32]
let inplace: Bool // let inplace: Bool
var output: Texture<P> var output: Texture<P>
} }
......
///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License. */
import Foundation
/// Parameters of the resize-bilinear operator: the input and output textures
/// plus the `out_h` / `out_w` attributes read from the op description.
class ResizeBilinearParam<P: PrecisionType>: OpParam {
    typealias ParamPrecisionType = P

    /// Source texture, looked up through `inputX`.
    let input: Texture<P>
    /// Destination texture, looked up through `outputOut`.
    var output: Texture<P>
    /// Value of the "out_h" attribute.
    let out_h: Int32
    /// Value of the "out_w" attribute.
    let out_w: Int32

    /// Resolves the textures from `inScope` and the attributes from `opDesc`.
    ///
    /// - Throws: Whatever error the lookup helpers throw; it is propagated unchanged.
    required init(opDesc: OpDesc, inScope: Scope) throws {
        input = try ResizeBilinearParam.inputX(inputs: opDesc.inputs, from: inScope)
        // A precondition check on the input layout is currently disabled:
        //   if (input.transpose != [0, 2, 3, 1]) || (input.tensorDim.cout() != 4) {
        //       fatalError()
        //   }
        output = try ResizeBilinearParam.outputOut(outputs: opDesc.outputs, from: inScope)
        out_h = try ResizeBilinearParam.getAttr(key: "out_h", attrs: opDesc.attrs)
        out_w = try ResizeBilinearParam.getAttr(key: "out_w", attrs: opDesc.attrs)
    }
}
/// Operator that runs `ResizeBilinearKernel` with a `ResizeBilinearParam`.
class ResizeBilinearOp<P: PrecisionType>: Operator<ResizeBilinearKernel<P>, ResizeBilinearParam<P>>, Runable, Creator, InferShaperable {
    typealias OpType = ResizeBilinearOp<P>

    /// Intentionally empty: the output dimensions are not derived here.
    func inferShape() {
        // Disabled: para.output.dim = para.input.dim
    }

    /// Encodes the kernel into `buffer`; any error from the kernel is propagated unchanged.
    func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
        try kernel.compute(commandBuffer: buffer, param: para)
    }

    /// Prints the debug header line for this operator's output.
    func delogOutput() {
        print(" \(type) output: ")
    }
}
///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License. */
import Foundation
/// Parameters of the shape operator; only the output texture is resolved.
class ShapeParam<P: PrecisionType>: OpParam {
    typealias ParamPrecisionType = P

    /// Destination texture, looked up through `output(outputs:from:)`.
    var output: Texture<P>

    /// Resolves the output texture from `inScope`.
    ///
    /// - Throws: Whatever error the lookup helper throws; it is propagated unchanged.
    required init(opDesc: OpDesc, inScope: Scope) throws {
        output = try ShapeParam.output(outputs: opDesc.outputs, from: inScope)
    }
}
/// Operator that runs `ShapeKernel` with a `ShapeParam`.
///
/// The generic arguments and `OpType` name the Shape types; the original
/// declaration had been copied from `SplitOp` and still referred to
/// `SplitKernel`, `SplitParam` and `SplitOp`.
class ShapeOp<P: PrecisionType>: Operator<ShapeKernel<P>, ShapeParam<P>>, Runable, Creator, InferShaperable {
    typealias OpType = ShapeOp<P>

    /// Intentionally empty: the output dimensions are not derived here.
    func inferShape() {
        // Disabled: para.output.dim = para.input.dim
    }

    /// Encodes the kernel into `buffer`; any error from the kernel is propagated unchanged.
    func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
        try kernel.compute(commandBuffer: buffer, param: para)
    }

    /// Prints the debug header line for this operator's output.
    func delogOutput() {
        print(" \(type) output: ")
    }
}
///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License. */
import Foundation
/// Parameters of the split operator; only the output texture is resolved.
class SplitParam<P: PrecisionType>: OpParam {
    typealias ParamPrecisionType = P

    /// Destination texture, looked up through `output(outputs:from:)`.
    var output: Texture<P>

    /// Resolves the output texture from `inScope`.
    ///
    /// - Throws: Whatever error the lookup helper throws; it is propagated unchanged.
    required init(opDesc: OpDesc, inScope: Scope) throws {
        output = try SplitParam.output(outputs: opDesc.outputs, from: inScope)
    }
}
/// Operator that runs `SplitKernel` with a `SplitParam`.
class SplitOp<P: PrecisionType>: Operator<SplitKernel<P>, SplitParam<P>>, Runable, Creator, InferShaperable {
    typealias OpType = SplitOp<P>

    /// Intentionally empty: the output dimensions are not derived here.
    func inferShape() {
        // Disabled: para.output.dim = para.input.dim
    }

    /// Encodes the kernel into `buffer`; any error from the kernel is propagated unchanged.
    func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
        try kernel.compute(commandBuffer: buffer, param: para)
    }

    /// Prints the debug header line for this operator's output.
    func delogOutput() {
        print(" \(type) output: ")
    }
}
...@@ -38,6 +38,38 @@ extension InputTexture { ...@@ -38,6 +38,38 @@ extension InputTexture {
} }
} }
/*
4 维 tensor 存储 texture,要考虑 transpose
transpose 之后的维度是 [a, b, c, d],对应的texture_2darray
.width = c
.height = b
.len = (a * d + 3) / 4
低于 4 维的 tensor,transpose 必须为 [0, 1, 2, 3],即不考虑 transpose
// TODO transpose 对于低维 tensor 的扩展原则。。。
// [a, b] -> [1, 1, a, b] transpose 必须为 [0, 1, x, x]
// [a] -> [1, 1, 1, a] transpose 必须为 [0, 1, 2, 3]
// [a, b, c] -> [1, a, b, c] transpose 必须为 [0, x, x, x]
3 维 tensor [a, b, c] 对应的 texture_2darray,
.width = c
.height = b
.len = (a + 3) / 4
2 维 tensor [a, b] 对应的 texture_2darray
.width = (b + 3) / 4
.height = a
.len = 1
1 维 tensor [a] 对应的 texture_2darray
.width = (a + 3) / 4
.height = 1
.len = 1
*/
public class Texture<P: PrecisionType>: Tensorial { public class Texture<P: PrecisionType>: Tensorial {
var dim: Dim var dim: Dim
public var tensorDim: Dim public var tensorDim: Dim
...@@ -62,6 +94,11 @@ public class Texture<P: PrecisionType>: Tensorial { ...@@ -62,6 +94,11 @@ public class Texture<P: PrecisionType>: Tensorial {
func initTexture(device: MTLDevice, inTranspose: [Int] = [0, 1, 2, 3], computePrecision: ComputePrecision = .Float16) { func initTexture(device: MTLDevice, inTranspose: [Int] = [0, 1, 2, 3], computePrecision: ComputePrecision = .Float16) {
transpose = inTranspose transpose = inTranspose
for i in 0..<(4 - tensorDim.cout()) {
if i != inTranspose[i] {
fatalError()
}
}
let newDim = transpose.map { padToFourDim[$0] } let newDim = transpose.map { padToFourDim[$0] }
let newLayout = transpose.map { layout.layoutWithDim[$0] } let newLayout = transpose.map { layout.layoutWithDim[$0] }
...@@ -70,14 +107,25 @@ public class Texture<P: PrecisionType>: Tensorial { ...@@ -70,14 +107,25 @@ public class Texture<P: PrecisionType>: Tensorial {
dim = Dim.init(inDim: newDim) dim = Dim.init(inDim: newDim)
let tmpTextureDes = MTLTextureDescriptor.init() let tmpTextureDes = MTLTextureDescriptor.init()
tmpTextureDes.width = newDim[2]
// layout.W ?? 1
tmpTextureDes.height = newDim[1]
// layout.H ?? 1
tmpTextureDes.depth = 1
tmpTextureDes.arrayLength = ((newDim[0]) * (newDim[3]) + 3) / 4
tmpTextureDes.textureType = .type2DArray tmpTextureDes.textureType = .type2DArray
tmpTextureDes.depth = 1
switch tensorDim.cout() {
case 4:
tmpTextureDes.width = newDim[2]
tmpTextureDes.height = newDim[1]
tmpTextureDes.arrayLength = ((newDim[0]) * (newDim[3]) + 3) / 4
case 3:
tmpTextureDes.width = newDim[3]
tmpTextureDes.height = newDim[2]
tmpTextureDes.arrayLength = (newDim[1] + 3) / 4
case 2, 1:
tmpTextureDes.width = (newDim[3] + 3) / 4
tmpTextureDes.height = newDim[2]
tmpTextureDes.arrayLength = 1
default:
fatalError("unreachable")
}
if computePrecision == .Float16 { if computePrecision == .Float16 {
tmpTextureDes.pixelFormat = .rgba16Float tmpTextureDes.pixelFormat = .rgba16Float
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册