提交 45c202e5 编写于 作者: D dolphin8 提交者: GitHub

Merge pull request #970 from dolphin8/metal

Metal
...@@ -16,9 +16,13 @@ ...@@ -16,9 +16,13 @@
4AA1EA92214665D700D0F791 /* ShapeOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA91214665D700D0F791 /* ShapeOp.swift */; }; 4AA1EA92214665D700D0F791 /* ShapeOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA91214665D700D0F791 /* ShapeOp.swift */; };
4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA932146661500D0F791 /* ShapeKernel.swift */; }; 4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA932146661500D0F791 /* ShapeKernel.swift */; };
4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA972146666500D0F791 /* FlattenOp.swift */; }; 4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA972146666500D0F791 /* FlattenOp.swift */; };
4AA1EA9E2148D6F900D0F791 /* ConcatKernel.inc.metal in Headers */ = {isa = PBXBuildFile; fileRef = 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */; };
4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */; };
4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */; };
4AA1EAA4214A295C00D0F791 /* Split.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA3214A295C00D0F791 /* Split.inc.metal */; };
4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; }; 4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; };
4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; }; 4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; };
4AF928822135673D005B6C3A /* Concat.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* Concat.metal */; }; 4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; };
4AF9288421357BE3005B6C3A /* Elementwise.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9288321357BE3005B6C3A /* Elementwise.metal */; }; 4AF9288421357BE3005B6C3A /* Elementwise.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9288321357BE3005B6C3A /* Elementwise.metal */; };
D3831F70E7E0B565B9AC22DA /* Pods_paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = DD2E06330A1E7129C918DB46 /* Pods_paddle_mobile.framework */; }; D3831F70E7E0B565B9AC22DA /* Pods_paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = DD2E06330A1E7129C918DB46 /* Pods_paddle_mobile.framework */; };
FC0226562138F33800F395E2 /* TransposeKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC0226552138F33800F395E2 /* TransposeKernel.metal */; }; FC0226562138F33800F395E2 /* TransposeKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC0226552138F33800F395E2 /* TransposeKernel.metal */; };
...@@ -124,9 +128,13 @@ ...@@ -124,9 +128,13 @@
4AA1EA91214665D700D0F791 /* ShapeOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeOp.swift; sourceTree = "<group>"; }; 4AA1EA91214665D700D0F791 /* ShapeOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeOp.swift; sourceTree = "<group>"; };
4AA1EA932146661500D0F791 /* ShapeKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeKernel.swift; sourceTree = "<group>"; }; 4AA1EA932146661500D0F791 /* ShapeKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeKernel.swift; sourceTree = "<group>"; };
4AA1EA972146666500D0F791 /* FlattenOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenOp.swift; sourceTree = "<group>"; }; 4AA1EA972146666500D0F791 /* FlattenOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenOp.swift; sourceTree = "<group>"; };
4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ConcatKernel.inc.metal; sourceTree = "<group>"; };
4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ReshapeKernel.inc.metal; sourceTree = "<group>"; };
4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenKernel.swift; sourceTree = "<group>"; };
4AA1EAA3214A295C00D0F791 /* Split.inc.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Split.inc.metal; sourceTree = "<group>"; };
4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; }; 4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; };
4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; }; 4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; };
4AF928812135673D005B6C3A /* Concat.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Concat.metal; sourceTree = "<group>"; }; 4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = "<group>"; };
4AF9288321357BE3005B6C3A /* Elementwise.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Elementwise.metal; sourceTree = "<group>"; }; 4AF9288321357BE3005B6C3A /* Elementwise.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Elementwise.metal; sourceTree = "<group>"; };
CDF58151D902A1CBAE56A0C2 /* Pods-paddle-mobile.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile.debug.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile/Pods-paddle-mobile.debug.xcconfig"; sourceTree = "<group>"; }; CDF58151D902A1CBAE56A0C2 /* Pods-paddle-mobile.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile.debug.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile/Pods-paddle-mobile.debug.xcconfig"; sourceTree = "<group>"; };
DD2E06330A1E7129C918DB46 /* Pods_paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; }; DD2E06330A1E7129C918DB46 /* Pods_paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; };
...@@ -391,6 +399,7 @@ ...@@ -391,6 +399,7 @@
FCD04E6720F315020007374F /* PoolKernel.swift */, FCD04E6720F315020007374F /* PoolKernel.swift */,
FCD04E6B20F31A280007374F /* SoftmaxKernel.swift */, FCD04E6B20F31A280007374F /* SoftmaxKernel.swift */,
FCD04E6F20F31B720007374F /* ReshapeKernel.swift */, FCD04E6F20F31B720007374F /* ReshapeKernel.swift */,
4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */,
FCD04E7320F3437E0007374F /* ConvAddKernel.swift */, FCD04E7320F3437E0007374F /* ConvAddKernel.swift */,
FCBCCC5A2122F66F00D94F7E /* ConvBNReluKernel.swift */, FCBCCC5A2122F66F00D94F7E /* ConvBNReluKernel.swift */,
FCBCCC602122FBDF00D94F7E /* PriorBoxKernel.swift */, FCBCCC602122FBDF00D94F7E /* PriorBoxKernel.swift */,
...@@ -437,12 +446,14 @@ ...@@ -437,12 +446,14 @@
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC27990D21341016000B6BAD /* BoxCoder.metal */, FC27990D21341016000B6BAD /* BoxCoder.metal */,
4AF928812135673D005B6C3A /* Concat.metal */, 4AF928812135673D005B6C3A /* ConcatKernel.metal */,
4AA1EA9D2148D6F900D0F791 /* ConcatKernel.inc.metal */,
4AF9288321357BE3005B6C3A /* Elementwise.metal */, 4AF9288321357BE3005B6C3A /* Elementwise.metal */,
FC1B16B220EC9A4F00678B91 /* Kernels.metal */, FC1B16B220EC9A4F00678B91 /* Kernels.metal */,
FC4CB74820F0B954007C0C6D /* ConvKernel.metal */, FC4CB74820F0B954007C0C6D /* ConvKernel.metal */,
4AF928762133F1DB005B6C3A /* BoxCoder.metal */, 4AF928762133F1DB005B6C3A /* BoxCoder.metal */,
4AA1EA8F214664CD00D0F791 /* Split.metal */, 4AA1EA8F214664CD00D0F791 /* Split.metal */,
4AA1EAA3214A295C00D0F791 /* Split.inc.metal */,
4AA1EA892146631C00D0F791 /* BilinearInterp.metal */, 4AA1EA892146631C00D0F791 /* BilinearInterp.metal */,
4AF9287821341661005B6C3A /* Softmax.metal */, 4AF9287821341661005B6C3A /* Softmax.metal */,
FCEB6849212F00DB00D2448E /* PreluKernel.metal */, FCEB6849212F00DB00D2448E /* PreluKernel.metal */,
...@@ -450,6 +461,7 @@ ...@@ -450,6 +461,7 @@
FCDDC6CB212FDFDB00E5EF74 /* ReluKernel.metal */, FCDDC6CB212FDFDB00E5EF74 /* ReluKernel.metal */,
FCDDC6CE212FE14700E5EF74 /* PriorBoxKernel.metal */, FCDDC6CE212FE14700E5EF74 /* PriorBoxKernel.metal */,
FCA3A1622132A4AC00084FE5 /* ReshapeKernel.metal */, FCA3A1622132A4AC00084FE5 /* ReshapeKernel.metal */,
4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.inc.metal */,
FCA3A1642132A5EB00084FE5 /* Common.metal */, FCA3A1642132A5EB00084FE5 /* Common.metal */,
FCA67B1621364EF000BD58AA /* ConvTransposeKernel.metal */, FCA67B1621364EF000BD58AA /* ConvTransposeKernel.metal */,
FCA67CD42138272900BD58AA /* ConvAddMetal.metal */, FCA67CD42138272900BD58AA /* ConvAddMetal.metal */,
...@@ -471,6 +483,7 @@ ...@@ -471,6 +483,7 @@
FC4FD9792140E4980073E130 /* PaddleMobile.h in Headers */, FC4FD9792140E4980073E130 /* PaddleMobile.h in Headers */,
FC292C85214257CB00CF622F /* CPUCompute.h in Headers */, FC292C85214257CB00CF622F /* CPUCompute.h in Headers */,
FC292C5421421B2F00CF622F /* PaddleMobileGPU.h in Headers */, FC292C5421421B2F00CF622F /* PaddleMobileGPU.h in Headers */,
4AA1EA9E2148D6F900D0F791 /* ConcatKernel.inc.metal in Headers */,
FC039B6F20E11C3C0081E9F8 /* paddle_mobile.h in Headers */, FC039B6F20E11C3C0081E9F8 /* paddle_mobile.h in Headers */,
); );
runOnlyForDeploymentPostprocessing = 0; runOnlyForDeploymentPostprocessing = 0;
...@@ -574,6 +587,7 @@ ...@@ -574,6 +587,7 @@
FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */, FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */,
FCDE8A33212A917900F4A8F6 /* ConvTransposeOp.swift in Sources */, FCDE8A33212A917900F4A8F6 /* ConvTransposeOp.swift in Sources */,
FCBCCC6B2123071700D94F7E /* BoxcoderOp.swift in Sources */, FCBCCC6B2123071700D94F7E /* BoxcoderOp.swift in Sources */,
4AA1EAA4214A295C00D0F791 /* Split.inc.metal in Sources */,
FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */, FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */,
4AF9288421357BE3005B6C3A /* Elementwise.metal in Sources */, 4AF9288421357BE3005B6C3A /* Elementwise.metal in Sources */,
FCD04E7020F31B720007374F /* ReshapeKernel.swift in Sources */, FCD04E7020F31B720007374F /* ReshapeKernel.swift in Sources */,
...@@ -610,6 +624,7 @@ ...@@ -610,6 +624,7 @@
FCBCCC592122F42700D94F7E /* ConvBNReluOp.swift in Sources */, FCBCCC592122F42700D94F7E /* ConvBNReluOp.swift in Sources */,
FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */, FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */,
FC9D038420E23B01000F735A /* Texture.swift in Sources */, FC9D038420E23B01000F735A /* Texture.swift in Sources */,
4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */,
4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */, 4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */,
FCBCCC652122FCD700D94F7E /* TransposeOp.swift in Sources */, FCBCCC652122FCD700D94F7E /* TransposeOp.swift in Sources */,
FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */, FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */,
...@@ -624,7 +639,7 @@ ...@@ -624,7 +639,7 @@
FCBCCC6F2123097100D94F7E /* MulticlassNMSOp.swift in Sources */, FCBCCC6F2123097100D94F7E /* MulticlassNMSOp.swift in Sources */,
FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */, FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */,
FC292C872142624800CF622F /* Genet.swift in Sources */, FC292C872142624800CF622F /* Genet.swift in Sources */,
4AF928822135673D005B6C3A /* Concat.metal in Sources */, 4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */,
FCBCCC632122FCC000D94F7E /* TransposeKernel.swift in Sources */, FCBCCC632122FCC000D94F7E /* TransposeKernel.swift in Sources */,
FCBCCC71212309A700D94F7E /* MulticlassNMSKernel.swift in Sources */, FCBCCC71212309A700D94F7E /* MulticlassNMSKernel.swift in Sources */,
FCDC0FEB21099A1D00DC9EFB /* Tools.swift in Sources */, FCDC0FEB21099A1D00DC9EFB /* Tools.swift in Sources */,
...@@ -650,6 +665,7 @@ ...@@ -650,6 +665,7 @@
FCBCCC67212306B000D94F7E /* ConcatOp.swift in Sources */, FCBCCC67212306B000D94F7E /* ConcatOp.swift in Sources */,
FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */, FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */,
FCEB684A212F00DB00D2448E /* PreluKernel.metal in Sources */, FCEB684A212F00DB00D2448E /* PreluKernel.metal in Sources */,
4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.inc.metal in Sources */,
FC9A19E32148C31300CD9CBF /* MobilenetSSD_AR.swift in Sources */, FC9A19E32148C31300CD9CBF /* MobilenetSSD_AR.swift in Sources */,
FCDDC6CF212FE14700E5EF74 /* PriorBoxKernel.metal in Sources */, FCDDC6CF212FE14700E5EF74 /* PriorBoxKernel.metal in Sources */,
FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */, FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */,
......
...@@ -83,38 +83,38 @@ public class PaddleMobileUnitTest { ...@@ -83,38 +83,38 @@ public class PaddleMobileUnitTest {
} }
public func testConcat() { public func testConcat() {
let buffer = queue.makeCommandBuffer() ?! "buffer is nil" // let buffer = queue.makeCommandBuffer() ?! "buffer is nil"
var it: [[Float32]] = [] // var it: [[Float32]] = []
for _ in 0..<7 { // for _ in 0..<7 {
it.append((0..<12).map { Float32($0) }) // it.append((0..<12).map { Float32($0) })
} // }
let input = it.map { device.tensor2texture(value: $0, dim: [3, 4]) } // let input = it.map { device.tensor2texture(value: $0, dim: [3, 4]) }
let output = device.tensor2texture(value: [Float32](), dim: [3, 28]) // let output = device.tensor2texture(value: [Float32](), dim: [3, 28])
//
let param = ConcatTestParam.init( // let param = ConcatTestParam.init(
input: input, // input: input,
output: output, // output: output,
dims: [[3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4]], // dims: [[3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4]],
axis: 1, // axis: 1,
odim: [3, 28] // odim: [3, 28]
) // )
let concatKernel = ConcatKernel<Float32>.init(device: device, testParam: param) // let concatKernel = ConcatKernel<Float32>.init(device: device, testParam: param)
concatKernel.test(cmdBuffer: buffer, param: param) // concatKernel.test(cmdBuffer: buffer, param: param)
buffer.addCompletedHandler { (buffer) in // buffer.addCompletedHandler { (buffer) in
for i in 0..<it.count { // for i in 0..<it.count {
let _: Float32? = input[i].logDesc() // let _: Float32? = input[i].logDesc()
self.tensorPrint(tensor: it[i], dim: [3, 4]) // self.tensorPrint(tensor: it[i], dim: [3, 4])
} // }
let _: Float32? = output.logDesc() // let _: Float32? = output.logDesc()
let tx: [Float32] = self.device.texture2tensor(texture: output, dim: [3, 28]) // let tx: [Float32] = self.device.texture2tensor(texture: output, dim: [3, 28])
self.tensorPrint(tensor: tx, dim: [3, 28]) // self.tensorPrint(tensor: tx, dim: [3, 28])
} // }
//
buffer.commit() // buffer.commit()
} }
public func testReshape() { public func testReshape() {
let buffer = queue.makeCommandBuffer() ?! "buffer is nil" // let buffer = queue.makeCommandBuffer() ?! "buffer is nil"
// let input: [Float32] = (0..<24).map { Float32($0) } // let input: [Float32] = (0..<24).map { Float32($0) }
// let inTexture = device.tensor2texture(value: input, dim: [2, 3, 4]) // let inTexture = device.tensor2texture(value: input, dim: [2, 3, 4])
// let outTexture = device.tensor2texture(value: [Float32](), dim: [4, 6]) // let outTexture = device.tensor2texture(value: [Float32](), dim: [4, 6])
...@@ -139,32 +139,32 @@ public class PaddleMobileUnitTest { ...@@ -139,32 +139,32 @@ public class PaddleMobileUnitTest {
// self.tensorPrint(tensor: tx, dim: [4, 6]) // self.tensorPrint(tensor: tx, dim: [4, 6])
// } // }
let input: [Float32] = (0..<24).map { Float32($0) } // let input: [Float32] = (0..<24).map { Float32($0) }
let inTexture = device.tensor2texture(value: input, dim: [2, 3, 4]) // let inTexture = device.tensor2texture(value: input, dim: [2, 3, 4])
let outTexture = device.tensor2texture(value: [Float32](), dim: [24]) // let outTexture = device.tensor2texture(value: [Float32](), dim: [24])
let mp = ReshapeMetalParam.init( // let mp = ReshapeMetalParam.init(
idim: (1, 2, 3, 4), // idim: (1, 2, 3, 4),
itrans: (0, 1, 2, 3), // itrans: (0, 1, 2, 3),
odim: (1, 1, 1, 24), // odim: (1, 1, 1, 24),
otrans: (0, 1, 2, 3) // otrans: (0, 1, 2, 3)
) // )
let param = ReshapeTestParam.init( // let param = ReshapeTestParam.init(
inputTexture: inTexture, // inputTexture: inTexture,
outputTexture: outTexture, // outputTexture: outTexture,
param: mp // param: mp
) // )
let reshapeKernel = ReshapeKernel<Float32>.init(device: device, testParam: param) // let reshapeKernel = ReshapeKernel<Float32>.init(device: device, testParam: param)
reshapeKernel.test(commandBuffer: buffer, testParam: param) // reshapeKernel.test(commandBuffer: buffer, testParam: param)
buffer.addCompletedHandler { (buffer) in // buffer.addCompletedHandler { (buffer) in
let _: Float32? = inTexture.logDesc() // let _: Float32? = inTexture.logDesc()
let _: Float32? = outTexture.logDesc() // let _: Float32? = outTexture.logDesc()
self.tensorPrint(tensor: input, dim: [2, 3, 4]) // self.tensorPrint(tensor: input, dim: [2, 3, 4])
let tx: [Float32] = self.device.texture2tensor(texture: outTexture, dim: [24]) // let tx: [Float32] = self.device.texture2tensor(texture: outTexture, dim: [24])
self.tensorPrint(tensor: tx, dim: [24]) // self.tensorPrint(tensor: tx, dim: [24])
} // }
//
//
buffer.commit() // buffer.commit()
} }
public func testTranspose() { public func testTranspose() {
......
...@@ -30,7 +30,7 @@ public class MobileNet_ssd_AR: Net{ ...@@ -30,7 +30,7 @@ public class MobileNet_ssd_AR: Net{
class MobilenetssdPreProccess: CusomKernel { class MobilenetssdPreProccess: CusomKernel {
init(device: MTLDevice) { init(device: MTLDevice) {
let s = CusomKernel.Shape.init(inWidth: 160, inHeight: 160, inChannel: 3) let s = CusomKernel.Shape.init(inWidth: 160, inHeight: 160, inChannel: 3)
super.init(device: device, inFunctionName: "mobilent_ar_preprocess_half", outputDim: s, usePaddleMobileLib: false) super.init(device: device, inFunctionName: "mobilent_ar_preprocess", outputDim: s, usePaddleMobileLib: false)
} }
} }
......
...@@ -19,11 +19,14 @@ class BatchNormParam<P: PrecisionType>: OpParam { ...@@ -19,11 +19,14 @@ class BatchNormParam<P: PrecisionType>: OpParam {
required init(opDesc: OpDesc, inScope: Scope) throws { required init(opDesc: OpDesc, inScope: Scope) throws {
do { do {
input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: inScope) input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: inScope)
if input.transpose != [0, 2, 3, 1] {
fatalError("batch norm only accepts NHWC")
}
output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: inScope) output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: inScope)
inputBias = try BatchNormParam.inputBiase(inputs: opDesc.paraInputs, from: inScope) bias = try BatchNormParam.getFirstTensor(key: "Bias", map: opDesc.paraInputs, from: inScope)
inputMean = try BatchNormParam.inputMean(inputs: opDesc.paraInputs, from: inScope) mean = try BatchNormParam.getFirstTensor(key: "Mean", map: opDesc.paraInputs, from: inScope)
inputScale = try BatchNormParam.inputScale(inputs: opDesc.paraInputs, from: inScope) scale = try BatchNormParam.getFirstTensor(key: "Scale", map: opDesc.paraInputs, from: inScope)
inputVariance = try BatchNormParam.inputVariance(inputs: opDesc.paraInputs, from: inScope) variance = try BatchNormParam.getFirstTensor(key: "Variance", map: opDesc.paraInputs, from: inScope)
epsilon = try BatchNormParam.getAttr(key: "epsilon", attrs: opDesc.attrs) epsilon = try BatchNormParam.getAttr(key: "epsilon", attrs: opDesc.attrs)
momentum = try BatchNormParam.getAttr(key: "momentum", attrs: opDesc.attrs) momentum = try BatchNormParam.getAttr(key: "momentum", attrs: opDesc.attrs)
} catch let error { } catch let error {
...@@ -32,10 +35,10 @@ class BatchNormParam<P: PrecisionType>: OpParam { ...@@ -32,10 +35,10 @@ class BatchNormParam<P: PrecisionType>: OpParam {
} }
let input: Texture<P> let input: Texture<P>
var output: Texture<P> var output: Texture<P>
let inputBias: Tensor<ParamPrecisionType> let bias: Tensor<P>
let inputMean: Tensor<ParamPrecisionType> let mean: Tensor<P>
let inputScale: Tensor<ParamPrecisionType> let scale: Tensor<P>
let inputVariance: Tensor<ParamPrecisionType> let variance: Tensor<P>
let epsilon: Float let epsilon: Float
let momentum: Float let momentum: Float
} }
......
...@@ -19,15 +19,15 @@ class BilinearInterpParam<P: PrecisionType>: OpParam { ...@@ -19,15 +19,15 @@ class BilinearInterpParam<P: PrecisionType>: OpParam {
required init(opDesc: OpDesc, inScope: Scope) throws { required init(opDesc: OpDesc, inScope: Scope) throws {
do { do {
input = try BilinearInterpParam.inputX(inputs: opDesc.inputs, from: inScope) input = try BilinearInterpParam.inputX(inputs: opDesc.inputs, from: inScope)
// if (input.transpose != [0, 2, 3, 1]) || (input.tensorDim.cout() != 4) {
// fatalError()
// }
output = try BilinearInterpParam.outputOut(outputs: opDesc.outputs, from: inScope) output = try BilinearInterpParam.outputOut(outputs: opDesc.outputs, from: inScope)
out_h = try BilinearInterpParam.getAttr(key: "out_h", attrs: opDesc.attrs) out_h = try BilinearInterpParam.getAttr(key: "out_h", attrs: opDesc.attrs)
out_w = try BilinearInterpParam.getAttr(key: "out_w", attrs: opDesc.attrs) out_w = try BilinearInterpParam.getAttr(key: "out_w", attrs: opDesc.attrs)
} catch let error { } catch let error {
throw error throw error
} }
if (input.transpose != [0, 2, 3, 1]) || (input.tensorDim.cout() != 4) {
fatalError()
}
} }
let input: Texture<P> let input: Texture<P>
var output: Texture<P> var output: Texture<P>
...@@ -53,6 +53,15 @@ class BilinearInterpOp<P: PrecisionType>: Operator<BilinearInterpKernel<P>, Bili ...@@ -53,6 +53,15 @@ class BilinearInterpOp<P: PrecisionType>: Operator<BilinearInterpKernel<P>, Bili
func delogOutput() { func delogOutput() {
print(" \(type) output: ") print(" \(type) output: ")
let padToFourDim = para.output.padToFourDim
if para.output.transpose == [0, 1, 2, 3] {
let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3]))
print(outputArray.strideArray())
} else if para.output.transpose == [0, 2, 3, 1] {
print(para.output.metalTexture.toTensor(dim: (n: padToFourDim[0], c: padToFourDim[1], h: padToFourDim[2], w: padToFourDim[3])).strideArray())
} else {
fatalError(" not implemet")
}
} }
} }
......
...@@ -14,7 +14,24 @@ ...@@ -14,7 +14,24 @@
import Foundation import Foundation
class FlattenOp<P: PrecisionType>: Operator<ReshapeKernel<P>, ReshapeParam<P>>, Runable, Creator, InferShaperable{ class FlattenParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws {
do {
input = try FlattenParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try FlattenParam.outputOut(outputs: opDesc.outputs, from: inScope)
axis = try FlattenParam.getAttr(key: "axis", attrs: opDesc.attrs)
} catch let error {
throw error
}
}
let input: Texture<P>
var output: Texture<P>
let axis: Int
}
class FlattenOp<P: PrecisionType>: Operator<FlattenKernel<P>, FlattenParam<P>>, Runable, Creator, InferShaperable{
typealias OpType = FlattenOp<P> typealias OpType = FlattenOp<P>
......
...@@ -15,20 +15,20 @@ ...@@ -15,20 +15,20 @@
import Foundation import Foundation
class BatchNormKernel<P: PrecisionType>: Kernel, Computable { class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
// var newScale: MTLBuffer
// var newBias: MTLBuffer
//
required init(device: MTLDevice, param: BatchNormParam<P>) { required init(device: MTLDevice, param: BatchNormParam<P>) {
// guard let newScale = device.makeBuffer(length: param.inputScale.buffer.length) else { let count = param.variance.dim.numel()
// fatalError() let varianceP = param.variance.data.pointer
// } let meanP = param.mean.data.pointer
// let scaleP = param.scale.data.pointer
// guard let newBias = device.makeBuffer(length: param.inputBias.buffer.length) else { let biasP = param.scale.data.pointer
// fatalError() for i in 0..<count {
// } let invStd = P(1 / (Float32(varianceP[i]) + param.epsilon).squareRoot())
// self.newScale = newScale biasP[i] = biasP[i] - meanP[i] * invStd * scaleP[i]
// self.newBias = newBias scaleP[i] = invStd * scaleP[i]
// }
param.bias.initBuffer(device: device, precision: computePrecision)
param.scale.initBuffer(device: device, precision: computePrecision)
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision)
if computePrecision == .Float32 { if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "batchnorm") super.init(device: device, inFunctionName: "batchnorm")
} else if computePrecision == .Float16 { } else if computePrecision == .Float16 {
...@@ -36,37 +36,16 @@ class BatchNormKernel<P: PrecisionType>: Kernel, Computable { ...@@ -36,37 +36,16 @@ class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
} else { } else {
fatalError() fatalError()
} }
//
// let varianceBuffer : MTLBuffer = param.inputVariance.buffer
//
// var invStd: [Float32] = Array(repeating: 0, count: varianceBuffer.length)
// let varianceContents = varianceBuffer.contents().assumingMemoryBound(to: P.self)
// for i in 0..<(varianceBuffer.length / MemoryLayout<P>.stride) {
// invStd[i] = 1 / (Float32(varianceContents[i]) + param.epsilon).squareRoot()
// }
//
// let newScaleContents = newScale.contents().assumingMemoryBound(to: P.self)
// let newBiasContents = newBias.contents().assumingMemoryBound(to: P.self)
// let scale : MTLBuffer = param.inputScale.buffer
// let scaleContents = scale.contents().assumingMemoryBound(to: P.self)
// let bias : MTLBuffer = param.inputBias.buffer
// let biasContents = bias.contents().assumingMemoryBound(to: P.self)
// let meanContents = param.inputMean.buffer.contents().assumingMemoryBound(to: P.self)
//
// for i in 0..<(newScale.length / MemoryLayout<P>.stride) {
// newScaleContents[i] = P(invStd[i] * Float32(scaleContents[i]))
// newBiasContents[i] = P(Float32(biasContents[i]) - Float32(meanContents[i]) * invStd[i] * Float32(scaleContents[i]))
// }
} }
func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encoder is nil") throw PaddleMobileError.predictError(message: " encoder is nil")
} }
// encoder.setTexture(param.input.metalTexture, index: 0) encoder.setTexture(param.input.metalTexture, index: 0)
// encoder.setTexture(param.output.metalTexture, index: 1) encoder.setTexture(param.output.metalTexture, index: 1)
// encoder.setBuffer(newScale, offset: 0, index: 0) encoder.setBuffer(param.scale.buffer, offset: 0, index: 0)
// encoder.setBuffer(newBias, offset: 0, index: 1) encoder.setBuffer(param.bias.buffer, offset: 0, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding() encoder.endEncoding()
} }
......
...@@ -31,101 +31,111 @@ struct ConcatMetalParam { ...@@ -31,101 +31,111 @@ struct ConcatMetalParam {
} }
class ConcatKernel<P: PrecisionType>: Kernel, Computable{ class ConcatKernel<P: PrecisionType>: Kernel, Computable{
var v = "normal"
func encodeTest(_ cmdBuffer: MTLCommandBuffer, _ param: ConcatTestParam, _ istart: Int, _ iend: Int) { var pm = ConcatMetalParam.init()
let encoder = cmdBuffer.makeComputeCommandEncoder()! func compute(commandBuffer: MTLCommandBuffer, param: ConcatParam<P>) throws {
var p = ConcatMetalParam.init()
var odim: [Int32] = [1, 1, 1, 1] guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
for i in 0..<param.odim.count { throw PaddleMobileError.predictError(message: " encode is nil")
odim[4-param.odim.count+i] = Int32(param.odim[i])
}
p.odim = (odim[0], odim[1], odim[2], odim[3])
p.axis = Int32(4 - param.odim.count + param.axis)
for i in 0..<istart {
p.offset += Int32(param.dims[i][param.axis])
} }
var vdim: [Int32] = [] let num = param.input.count
for i in 0..<(iend - istart) { for i in 0..<num {
encoder.setTexture(param.input[i+istart], index: i) encoder.setTexture(param.input[i].metalTexture, index: i)
vdim.append(Int32(param.dims[i+istart][Int(param.axis)]))
} }
for i in (iend-istart)..<6 { encoder.setTexture(param.output.metalTexture, index: num)
encoder.setTexture(param.input[0], index: i) if v == "normal" {
vdim.append(0) encoder.setTexture(param.output.metalTexture, index: num + 1)
} }
p.vdim = (vdim[0], vdim[1], vdim[2], vdim[3], vdim[4], vdim[5]) encoder.setBytes(&pm, length: MemoryLayout<ConcatMetalParam>.size, index: 0)
encoder.setTexture(param.output, index: 6) encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.setTexture(param.output, index: 7)
encoder.setBytes(&p, length: MemoryLayout<ConcatMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output)
encoder.endEncoding() encoder.endEncoding()
} }
func encode(_ cmdBuffer: MTLCommandBuffer, _ param: ConcatParam<P>, _ istart: Int, _ iend: Int) throws { required init(device: MTLDevice, param: ConcatParam<P>) {
guard let encoder = cmdBuffer.makeComputeCommandEncoder() else { param.output.initTexture(device: device, inTranspose: param.transpose, computePrecision: computePrecision)
throw PaddleMobileError.predictError(message: " encode is nil") let orank = param.output.tensorDim.cout()
} let num = param.input.count
var p = ConcatMetalParam.init() assert(num <= 6)
let odim = (0..<4).map { Int32(param.output.dim[$0]) } var axis = 4 - param.output.tensorDim.cout() + param.axis
p.odim = (odim[0], odim[1], odim[2], odim[3])
p.axis = Int32(4 - param.output.tensorDim.cout() + param.axis)
for i in 0..<4 { for i in 0..<4 {
if Int32(param.transpose[i]) == p.axis { if param.transpose[i] == axis {
p.axis = Int32(i) axis = i
break break
} }
} }
for i in 0..<istart { pm.axis = Int32(axis)
p.offset += Int32(param.input[i+istart].dim[Int(p.axis)]) pm.odim = (Int32(param.output.dim[0]), Int32(param.output.dim[1]), Int32(param.output.dim[2]), Int32(param.output.dim[3]))
} pm.trans = (Int32(param.output.transpose[0]), Int32(param.output.transpose[1]), Int32(param.output.transpose[2]), Int32(param.output.transpose[3]))
var vdim: [Int32] = [] var vdim: [Int] = [0, 0, 0, 0, 0, 0]
for i in 0..<(iend - istart) { for i in 0..<num {
encoder.setTexture(param.input[i+istart].metalTexture, index: i) vdim[i] = param.input[i].dim[axis]
vdim.append(Int32(param.input[i+istart].dim[Int(p.axis)]))
}
for i in (iend-istart)..<6 {
encoder.setTexture(param.input[0].metalTexture, index: i)
vdim.append(0)
}
p.trans = (Int32(param.transpose[0]), Int32(param.transpose[1]), Int32(param.transpose[2]), Int32(param.transpose[3]))
p.vdim = (vdim[0], vdim[1], vdim[2], vdim[3], vdim[4], vdim[5])
encoder.setTexture(param.output.metalTexture, index: 6)
encoder.setTexture(param.output.metalTexture, index: 7)
encoder.setBytes(&p, length: MemoryLayout<ConcatMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
func compute(commandBuffer: MTLCommandBuffer, param: ConcatParam<P>) throws {
let group = param.input.count / 6
let remain = param.input.count % 6
for i in 0..<group {
try self.encode(commandBuffer, param, 6 * i, 6 * (i + 1))
}
if remain > 0 {
try self.encode(commandBuffer, param, 6 * group, param.input.count)
}
}
func test(cmdBuffer: MTLCommandBuffer, param: ConcatTestParam) {
let group = param.input.count / 6
let remain = param.input.count % 6
for i in 0..<group {
self.encodeTest(cmdBuffer, param, 6 * i, 6 * (i + 1))
} }
if remain > 0 { if orank == 4 {
self.encodeTest(cmdBuffer, param, 6 * group, param.input.count) if axis == 1 {
v = "y"
} else if axis == 2 {
v = "x"
} else {
if (param.output.dim[0] == 1) && axis == 3 {
var vz = true
for i in 0..<num {
if vdim[i] % 4 != 0 {
vz = false
break
}
}
if vz {
v = "z"
for i in 0..<num {
vdim[i] = vdim[i] / 4
}
}
}
}
} else if orank == 3 {
if axis == 2 {
v = "y"
} else if axis == 3 {
v = "x"
} else if axis == 1 {
var vz = true
for i in 0..<num {
if vdim[i] % 4 != 0 {
vz = false
break
}
}
if vz {
v = "z"
for i in 0..<num {
vdim[i] = vdim[i] / 4
}
}
}
} else {
if axis == 2 {
v = "y"
} else if axis == 3 {
var vx = true
for i in 0..<num {
if vdim[i] % 4 != 0 {
vx = false
break
}
}
if vx {
v = "x"
for i in 0..<num {
vdim[i] = vdim[i] / 4
}
}
}
} }
} pm.vdim = (Int32(vdim[0]), Int32(vdim[1]), Int32(vdim[2]), Int32(vdim[3]), Int32(vdim[4]), Int32(vdim[5]))
required init(device: MTLDevice, param: ConcatParam<P>) {
param.output.initTexture(device: device, inTranspose: param.transpose, computePrecision: computePrecision)
if computePrecision == .Float32 { if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "concat") super.init(device: device, inFunctionName: "concat_\(orank)_\(num)_\(v)_float")
} else if computePrecision == .Float16 { } else if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "concat_half") super.init(device: device, inFunctionName: "concat_\(orank)_\(num)_\(v)_half")
} else { } else {
fatalError() fatalError()
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
struct FlattenMetalParam {
var idim: (Int32, Int32, Int32, Int32)
var itrans: (Int32, Int32, Int32, Int32)
var odim: (Int32, Int32, Int32, Int32)
var otrans: (Int32, Int32, Int32, Int32)
}
class FlattenKernel<P: PrecisionType>: Kernel, Computable{
var metalParam: FlattenMetalParam
required init(device: MTLDevice, param: FlattenParam<P>) {
param.output.initTexture(device: device, computePrecision: computePrecision)
var id: [Int32] = [1, 1, 1, 1]
for i in 0..<param.input.tensorDim.cout() {
id[4-param.input.tensorDim.cout()+i] = Int32(param.input.tensorDim[i])
}
let it: [Int32] = param.input.transpose.map { Int32($0) }
var od: [Int32] = [1, 1, 1, 1]
for i in 0..<param.output.tensorDim.cout() {
od[4-param.output.tensorDim.cout()+i] = Int32(param.output.tensorDim[i])
}
let ot: [Int32] = param.output.transpose.map { Int32($0) }
metalParam = FlattenMetalParam.init(
idim: (id[0], id[1], id[2], id[3]),
itrans: (it[0], it[1], it[2], it[3]),
odim: (od[0], od[1], od[2], od[3]),
otrans: (ot[0], ot[1], ot[2], ot[3])
)
let irank = param.input.tensorDim.cout()
let orank = param.output.tensorDim.cout()
assert(orank == 2)
if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "reshape_\(irank)_2_float")
} else if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "reshape_\(irank)_2_half")
} else {
fatalError()
}
}
func compute(commandBuffer: MTLCommandBuffer, param: FlattenParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encoder is nil")
}
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setBytes(&metalParam, length: MemoryLayout<ReshapeMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
}
...@@ -49,10 +49,12 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -49,10 +49,12 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{
odim: (od[0], od[1], od[2], od[3]), odim: (od[0], od[1], od[2], od[3]),
otrans: (ot[0], ot[1], ot[2], ot[3]) otrans: (ot[0], ot[1], ot[2], ot[3])
) )
let irank = param.input.tensorDim.cout()
let orank = param.output.tensorDim.cout()
if computePrecision == .Float32 { if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "reshape") super.init(device: device, inFunctionName: "reshape_\(irank)_\(orank)_float")
} else if computePrecision == .Float16 { } else if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "reshape_half") super.init(device: device, inFunctionName: "reshape_\(irank)_\(orank)_half")
} else { } else {
fatalError() fatalError()
} }
...@@ -69,10 +71,11 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -69,10 +71,11 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{
} }
func compute(commandBuffer: MTLCommandBuffer, param: ReshapeParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: ReshapeParam<P>) throws {
print("reshape compute")
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encoder is nil") throw PaddleMobileError.predictError(message: " encoder is nil")
} }
encoder.setTexture(param.input.metalTexture, index: 0) encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1) encoder.setTexture(param.output.metalTexture, index: 1)
...@@ -81,15 +84,15 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -81,15 +84,15 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{
encoder.endEncoding() encoder.endEncoding()
} }
func test(commandBuffer: MTLCommandBuffer, testParam: ReshapeTestParam) { // func test(commandBuffer: MTLCommandBuffer, testParam: ReshapeTestParam) {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { // guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
fatalError() // fatalError()
} // }
encoder.setTexture(testParam.inputTexture, index: 0) // encoder.setTexture(testParam.inputTexture, index: 0)
encoder.setTexture(testParam.outputTexture, index: 1) // encoder.setTexture(testParam.outputTexture, index: 1)
var pm: ReshapeMetalParam = testParam.param // var pm: ReshapeMetalParam = testParam.param
encoder.setBytes(&pm, length: MemoryLayout<ReshapeMetalParam>.size, index: 0) // encoder.setBytes(&pm, length: MemoryLayout<ReshapeMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: testParam.outputTexture) // encoder.dispatch(computePipline: pipline, outTexture: testParam.outputTexture)
encoder.endEncoding() // encoder.endEncoding()
} // }
} }
...@@ -19,11 +19,12 @@ struct ShapeMetalParam { ...@@ -19,11 +19,12 @@ struct ShapeMetalParam {
class ShapeKernel<P: PrecisionType>: Kernel, Computable{ class ShapeKernel<P: PrecisionType>: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: ShapeParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: ShapeParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { print("shape compute")
throw PaddleMobileError.predictError(message: " encode is nil") // guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
} // throw PaddleMobileError.predictError(message: " encode is nil")
encoder.setTexture(param.output.metalTexture, index: 0) // }
encoder.endEncoding() // encoder.setTexture(param.output.metalTexture, index: 0)
// encoder.endEncoding()
} }
required init(device: MTLDevice, param: ShapeParam<P>) { required init(device: MTLDevice, param: ShapeParam<P>) {
......
...@@ -15,23 +15,76 @@ ...@@ -15,23 +15,76 @@
import Foundation import Foundation
struct SplitMetalParam { struct SplitMetalParam {
var idim: (Int32, Int32, Int32, Int32) = (1, 1, 1, 1)
var axis: Int32 = 0
var offset: Int32 = 0
var trans: (Int32, Int32, Int32, Int32) = (0, 1, 2, 3)
var vdim: (Int32, Int32, Int32, Int32) = (0, 0, 0, 0)
} }
class SplitKernel<P: PrecisionType>: Kernel, Computable{ class SplitKernel<P: PrecisionType>: Kernel, Computable{
var smp: SplitMetalParam
func compute(commandBuffer: MTLCommandBuffer, param: SplitParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: SplitParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil") throw PaddleMobileError.predictError(message: " encode is nil")
} }
encoder.setTexture(param.output.metalTexture, index: 0) encoder.setTexture(param.input.metalTexture, index: 0)
for i in 0..<param.outputList.count {
encoder.setTexture(param.outputList[i].metalTexture, index: i + 1)
}
encoder.setBytes(&smp, length: MemoryLayout<BoxcoderMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.input.metalTexture)
encoder.endEncoding() encoder.endEncoding()
} }
required init(device: MTLDevice, param: SplitParam<P>) { required init(device: MTLDevice, param: SplitParam<P>) {
param.output.initTexture(device: device, computePrecision: computePrecision) // param.output.initTexture(device: device, computePrecision: computePrecision)
let num = param.outputList.count
let rank = param.input.tensorDim.cout()
assert(num >= 2 && num <= 4)
for output in param.outputList {
output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision)
}
smp = SplitMetalParam.init()
smp.idim = (Int32(param.input.dim[0]), Int32(param.input.dim[1]), Int32(param.input.dim[2]), Int32(param.input.dim[3]))
smp.axis = Int32(param.axis + param.input.dim.cout() - param.input.tensorDim.cout())
for i in 0..<4 {
if param.input.transpose[i] == smp.axis {
smp.axis = Int32(i)
break
}
}
smp.trans = (Int32(param.input.transpose[0]), Int32(param.input.transpose[1]), Int32(param.input.transpose[2]), Int32(param.input.transpose[3]))
var vdim: [Int32] = [0, 0, 0, 0]
for i in 0..<num {
vdim[i] = Int32(param.outputList[i].tensorDim[param.axis])
}
smp.vdim = (vdim[0], vdim[1], vdim[2], vdim[3])
var v = "normal"
if rank == 4 {
if smp.axis == 1 {
v = "y"
} else if smp.axis == 2 {
v = "x"
}
} else if rank == 3 {
if smp.axis == 2 {
v = "y"
} else if smp.axis == 3 {
v = "x"
}
} else if rank == 2 {
if smp.axis == 2 {
v = "y"
}
}
if v == "normal" {
fatalError("split unsupported")
}
if computePrecision == .Float32 { if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "split") super.init(device: device, inFunctionName: "split_\(rank)_\(num)_\(v)")
} else if computePrecision == .Float16 { } else if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "split_half") super.init(device: device, inFunctionName: "split_\(rank)_\(num)_\(v)_half")
} else { } else {
fatalError() fatalError()
} }
......
...@@ -15,28 +15,28 @@ ...@@ -15,28 +15,28 @@
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
kernel void batchnorm_half(texture2d_array<half, access::read> inTexture [[texture(0)]], kernel void batchnorm(texture2d_array<float, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]], texture2d_array<float, access::write> outTexture [[texture(1)]],
const device half4 * newScale [[buffer(0)]], const device float4 * newScale [[buffer(0)]],
const device half4 * newBias [[buffer(1)]], const device float4 * newBias [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) { uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() || if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() || gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return; gid.z >= outTexture.get_array_size()) return;
const half4 input = inTexture.read(gid.xy, gid.z); const float4 input = inTexture.read(gid.xy, gid.z);
half4 output = input * newScale[gid.z] + newBias[gid.z]; float4 output = input * newScale[gid.z] + newBias[gid.z];
outTexture.write(output, gid.xy, gid.z); outTexture.write(output, gid.xy, gid.z);
} }
kernel void batchnorm(texture2d_array<float, access::read> inTexture [[texture(0)]], kernel void batchnorm_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]], texture2d_array<half, access::write> outTexture [[texture(1)]],
const device float4 * newScale [[buffer(0)]], const device half4 * newScale [[buffer(0)]],
const device float4 * newBias [[buffer(1)]], const device half4 * newBias [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) { uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() || if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() || gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return; gid.z >= outTexture.get_array_size()) return;
const float4 input = inTexture.read(gid.xy, gid.z); const half4 input = inTexture.read(gid.xy, gid.z);
float4 output = input * newScale[gid.z] + newBias[gid.z]; half4 output = input * newScale[gid.z] + newBias[gid.z];
outTexture.write(output, gid.xy, gid.z); outTexture.write(output, gid.xy, gid.z);
} }
...@@ -23,7 +23,7 @@ struct bilinear_interp_param { ...@@ -23,7 +23,7 @@ struct bilinear_interp_param {
}; };
kernel void bilinear_interp(texture2d_array<float, access::read> input [[texture(0)]], kernel void bilinear_interp(texture2d_array<float, access::read> input [[texture(0)]],
texture2d_array<float, access::write> output [[texture(2)]], texture2d_array<float, access::write> output [[texture(1)]],
constant bilinear_interp_param & pm [[buffer(0)]], constant bilinear_interp_param & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) { uint3 gid [[thread_position_in_grid]]) {
float4 r; float4 r;
...@@ -47,29 +47,29 @@ kernel void bilinear_interp(texture2d_array<float, access::read> input [[texture ...@@ -47,29 +47,29 @@ kernel void bilinear_interp(texture2d_array<float, access::read> input [[texture
output.write(r, gid.xy, gid.z); output.write(r, gid.xy, gid.z);
} }
kernel void bilinear_interp_half(texture2d_array<half, access::read> input [[texture(0)]], //kernel void bilinear_interp_half(texture2d_array<half, access::read> input [[texture(0)]],
texture2d_array<half, access::write> output [[texture(2)]], // texture2d_array<half, access::write> output [[texture(1)]],
constant bilinear_interp_param & pm [[buffer(0)]], // constant bilinear_interp_param & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) { // uint3 gid [[thread_position_in_grid]]) {
//
half4 r; // half4 r;
if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) { // if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) {
r = input.read(gid.xy, gid.z); // r = input.read(gid.xy, gid.z);
} else { // } else {
half w = gid.x * pm.ratio_w; // half w = gid.x * pm.ratio_w;
half h = gid.y * pm.ratio_h; // half h = gid.y * pm.ratio_h;
uint w0 = w, h0 = h; // uint w0 = w, h0 = h;
uint w1 = w0 + 1, h1 = h0 + 1; // uint w1 = w0 + 1, h1 = h0 + 1;
half w1lambda = w - w0, h1lambda = h - h0; // half w1lambda = w - w0, h1lambda = h - h0;
half w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda; // half w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda;
if (w1 >= input.get_width()) w1 = w0; // if (w1 >= input.get_width()) w1 = w0;
if (h1 >= input.get_height()) h1 = h0; // if (h1 >= input.get_height()) h1 = h0;
half4 r0 = input.read(uint2(w0, h0), gid.z); // half4 r0 = input.read(uint2(w0, h0), gid.z);
half4 r1 = input.read(uint2(w1, h0), gid.z); // half4 r1 = input.read(uint2(w1, h0), gid.z);
half4 r2 = input.read(uint2(w0, h1), gid.z); // half4 r2 = input.read(uint2(w0, h1), gid.z);
half4 r3 = input.read(uint2(w1, h1), gid.z); // half4 r3 = input.read(uint2(w1, h1), gid.z);
r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3); // r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3);
} // }
output.write(r, gid.xy, gid.z); // output.write(r, gid.xy, gid.z);
output.write(r, gid.xy, gid.z); // output.write(r, gid.xy, gid.z);
} //}
...@@ -15,6 +15,55 @@ ...@@ -15,6 +15,55 @@
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
inline void xyzn2abcd_1(int xyzn[4], int abcd[4]) {
abcd[0] = abcd[1] = abcd[2] = 1;
abcd[3] = xyzn[0] * 4 + xyzn[3];
}
inline void xyzn2abcd_2(int xyzn[4], int abcd[4]) {
abcd[0] = abcd[1] = 1;
abcd[2] = xyzn[1];
abcd[3] = xyzn[0] * 4 + xyzn[3];
}
inline void xyzn2abcd_3(int xyzn[4], int abcd[4]) {
abcd[0] = 1;
abcd[3] = xyzn[0];
abcd[2] = xyzn[1];
abcd[1] = xyzn[2] * 4 + xyzn[3];
}
inline void xyzn2abcd_4(int C, int xyzn[4], int abcd[4]) {
abcd[2] = xyzn[0];
abcd[1] = xyzn[1];
uint t = xyzn[2] * 4 + xyzn[3];
abcd[0] = t / C;
abcd[3] = t % C;
}
inline void abcd2xyzn_1(int abcd[4], int xyzn[4]) {
xyzn[1] = xyzn[2] = 1;
xyzn[0] = abcd[3] / 4;
xyzn[1] = abcd[3] % 4;
}
inline void abcd2xyzn_2(int abcd[4], int xyzn[4]) {
xyzn[2] = 1;
xyzn[1] = abcd[2];
xyzn[0] = abcd[3] / 4;
xyzn[1] = abcd[3] % 4;
}
inline void abcd2xyzn_3(int abcd[4], int xyzn[4]) {
xyzn[0] = abcd[3];
xyzn[1] = abcd[2];
xyzn[2] = abcd[1] / 4;
xyzn[3] = abcd[1] % 4;
}
inline void abcd2xyzn_4(int C, int abcd[4], int xyzn[4]) {
xyzn[0] = abcd[2];
xyzn[1] = abcd[1];
uint t = abcd[0] * C + abcd[3];
xyzn[2] = t / 4;
xyzn[3] = t % 4;
}
inline void xyzn2abcd(int C, int xyzn[4], int abcd[4]) { inline void xyzn2abcd(int C, int xyzn[4], int abcd[4]) {
abcd[2] = xyzn[0]; abcd[2] = xyzn[0];
abcd[1] = xyzn[1]; abcd[1] = xyzn[1];
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
struct ConcatParam {
int32_t odim[4];
int32_t axis;
int32_t offset;
int32_t trans[4];
int32_t vdim[6];
};
kernel void concat(texture2d_array<float, access::read> in0 [[texture(0)]],
texture2d_array<float, access::read> in1 [[texture(1)]],
texture2d_array<float, access::read> in2 [[texture(2)]],
texture2d_array<float, access::read> in3 [[texture(3)]],
texture2d_array<float, access::read> in4 [[texture(4)]],
texture2d_array<float, access::read> in5 [[texture(5)]],
texture2d_array<float, access::read> inx [[texture(6)]],
texture2d_array<float, access::write> out [[texture(7)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
ConcatParam cp = pm;
int xyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, abcd[4], oxyzn[4];
float4 r;
for (int i = 0; i < 4; i++) {
xyzn[3] = i;
xyzn2abcd(cp.odim[3], xyzn, abcd);
int k = abcd[cp.axis] - cp.offset;
int j = 0;
if (k < 0) {
r[i] = inx.read(gid.xy, gid.z)[i];
} else {
for (; j < 6; j++) {
if (k < cp.vdim[j]) {
break;
}
k -= cp.vdim[j];
}
int ta = cp.odim[cp.axis];
abcd[cp.axis] = k;
cp.odim[cp.axis] = cp.vdim[j];
abcd2xyzn(cp.odim[3], abcd, oxyzn);
cp.odim[cp.axis] = ta;
switch (j) {
case 0: r[i] = in0.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 1: r[i] = in1.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 2: r[i] = in2.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 3: r[i] = in3.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 4: r[i] = in4.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 5: r[i] = in5.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
}
}
}
out.write(r, gid.xy, gid.z);
}
kernel void concat_half(texture2d_array<half, access::read> in0 [[texture(0)]],
texture2d_array<half, access::read> in1 [[texture(1)]],
texture2d_array<half, access::read> in2 [[texture(2)]],
texture2d_array<half, access::read> in3 [[texture(3)]],
texture2d_array<half, access::read> in4 [[texture(4)]],
texture2d_array<half, access::read> in5 [[texture(5)]],
texture2d_array<half, access::read> inx [[texture(6)]],
texture2d_array<half, access::write> out [[texture(7)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
ConcatParam cp = pm;
int xyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, abcd[4], oxyzn[4];
half4 r;
for (int i = 0; i < 4; i++) {
xyzn[3] = i;
xyzn2abcd(cp.odim[3], xyzn, abcd);
int k = abcd[cp.axis] - cp.offset;
int j = 0;
if (k < 0) {
r[i] = inx.read(gid.xy, gid.z)[i];
} else {
for (; j < 6; j++) {
if (k < cp.vdim[j]) {
break;
}
k -= cp.vdim[j];
}
int ta = cp.odim[cp.axis];
abcd[cp.axis] = k;
cp.odim[cp.axis] = cp.vdim[j];
abcd2xyzn(cp.odim[3], abcd, oxyzn);
cp.odim[cp.axis] = ta;
switch (j) {
case 0: r[i] = in0.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 1: r[i] = in1.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 2: r[i] = in2.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 3: r[i] = in3.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 4: r[i] = in4.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 5: r[i] = in5.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
}
}
}
out.write(r, gid.xy, gid.z);
}
#ifdef P
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d
#define CONCAT5_(a, b, c, d, e) a ## _ ## b ## _ ## c ## _ ## d ## _ ## e
#define FUNC(f, r, n, v, p) CONCAT5_(f, r, n, v, p)
#define VECTOR(p, n) CONCAT2(p, n)
#define FUNC_R(f, r) CONCAT2_(f, r)
#if V == VX
#define VV x
#elif V == VY
#define VV y
#elif V == VZ
#define VV z
#else
#define VV normal
#endif
#if V == VNORMAL
//kernel void FUNC(concat, R, N, normal, P)(array<texture2d_array<P, access::read>, N> in [[texture(0)]],
// texture2d_array<P, access::read> out_x [[texture(N)]],
// texture2d_array<P, access::write> out [[texture(N+1)]],
// constant ConcatParam & pm [[buffer(0)]],
// uint3 gid [[thread_position_in_grid]]) {
//}
kernel void FUNC(concat, R, N, VV, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
texture2d_array<P, access::read> in1 [[texture(1)]],
#if N >= 3
texture2d_array<P, access::read> in2 [[texture(2)]],
#endif
#if N >= 4
texture2d_array<P, access::read> in3 [[texture(3)]],
#endif
#if N >= 5
texture2d_array<P, access::read> in4 [[texture(4)]],
#endif
#if N >= 6
texture2d_array<P, access::read> in5 [[texture(5)]],
#endif
texture2d_array<P, access::read> inx [[texture(N)]],
texture2d_array<P, access::write> out [[texture(N+1)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
ConcatParam cp = pm;
int xyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, abcd[4], oxyzn[4];
VECTOR(P, 4) r = inx.read(gid.xy, gid.z);
for (int i = 0; i < 4; i++) {
xyzn[3] = i;
#if R == 4
xyzn2abcd_4(cp.odim[3], xyzn, abcd);
#else
FUNC_R(xyzn2abcd, R)(xyzn, abcd);
#endif
int k = abcd[cp.axis] - cp.offset;
if (k < 0) continue;
int j = 0;
for (; j < N; j++) {
if (k < cp.vdim[j]) {
break;
}
k -= cp.vdim[j];
}
if (k > cp.vdim[N-1]) {
continue;
}
int ta = cp.odim[cp.axis];
abcd[cp.axis] = k;
cp.odim[cp.axis] = cp.vdim[j];
#if R == 4
abcd2xyzn_4(cp.odim[3], abcd, oxyzn);
#else
FUNC_R(abcd2xyzn, R)(abcd, oxyzn);
#endif
cp.odim[cp.axis] = ta;
switch (j) {
case 0: r[i] = in0.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 1: r[i] = in1.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
#if N >= 3
case 2: r[i] = in2.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
#endif
#if N >= 4
case 3: r[i] = in3.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
#endif
#if N >= 5
case 4: r[i] = in4.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
#endif
#if N >= 6
case 5: r[i] = in5.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
#endif
}
}
out.write(r, gid.xy, gid.z);
}
#endif // V == NORMAL
#if V == VX
kernel void FUNC(concat, R, N, VV, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
texture2d_array<P, access::read> in1 [[texture(1)]],
#if N >= 3
texture2d_array<P, access::read> in2 [[texture(2)]],
#endif // N >= 3
#if N >= 4
texture2d_array<P, access::read> in3 [[texture(3)]],
#endif // N >= 4
#if N >= 5
texture2d_array<P, access::read> in4 [[texture(4)]],
#endif // N >= 5
#if N >= 6
texture2d_array<P, access::read> in5 [[texture(5)]],
#endif // N >= 6
texture2d_array<P, access::write> out [[texture(N)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
int x = gid.x - pm.offset;
if (x < 0) return;
if (x < pm.vdim[0]) {
VECTOR(P, 4) r = in0.read(gid.xy, gid.z);
out.write(r, gid.xy, gid.z);
return;
}
x -= pm.vdim[0];
if (x < pm.vdim[1]) {
VECTOR(P, 4) r = in1.read(uint2(x, gid.y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#if N >= 3
x -= pm.vdim[1];
if (x < pm.vdim[2]) {
VECTOR(P, 4) r = in2.read(uint2(x, gid.y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 3
#if N >= 4
x -= pm.vdim[2];
if (x < pm.vdim[3]) {
VECTOR(P, 4) r = in3.read(uint2(x, gid.y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 4
#if N >= 5
x -= pm.vdim[3];
if (x < pm.vdim[4]) {
VECTOR(P, 4) r = in4.read(uint2(x, gid.y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 5
#if N >= 6
x -= pm.vdim[4];
if (x < pm.vdim[5]) {
VECTOR(P, 4) r = in5.read(uint2(x, gid.y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 6
}
#endif // V == VX
#if V == VY
kernel void FUNC(concat, R, N, VV, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
texture2d_array<P, access::read> in1 [[texture(1)]],
#if N >= 3
texture2d_array<P, access::read> in2 [[texture(2)]],
#endif // N >= 3
#if N >= 4
texture2d_array<P, access::read> in3 [[texture(3)]],
#endif // N >= 4
#if N >= 5
texture2d_array<P, access::read> in4 [[texture(4)]],
#endif // N >= 5
#if N >= 6
texture2d_array<P, access::read> in5 [[texture(5)]],
#endif // N >= 6
texture2d_array<P, access::write> out [[texture(N)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
int y = gid.y - pm.offset;
if (y < 0) return;
if (y < pm.vdim[0]) {
VECTOR(P, 4) r = in0.read(gid.xy, gid.z);
out.write(r, gid.xy, gid.z);
return;
}
y -= pm.vdim[0];
if (y < pm.vdim[1]) {
VECTOR(P, 4) r = in1.read(uint2(gid.x, y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#if N >= 3
y -= pm.vdim[1];
if (y < pm.vdim[2]) {
VECTOR(P, 4) r = in2.read(uint2(gid.x, y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 3
#if N >= 4
y -= pm.vdim[2];
if (y < pm.vdim[3]) {
VECTOR(P, 4) r = in3.read(uint2(gid.x, y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 4
#if N >= 5
y -= pm.vdim[3];
if (y < pm.vdim[4]) {
VECTOR(P, 4) r = in4.read(uint2(gid.x, y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 5
#if N >= 6
y -= pm.vdim[4];
if (y < pm.vdim[5]) {
VECTOR(P, 4) r = in5.read(uint2(gid.x, y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 6
}
#endif // V == VY
#if V == VZ
kernel void FUNC(concat, R, N, VV, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
texture2d_array<P, access::read> in1 [[texture(1)]],
#if N >= 3
texture2d_array<P, access::read> in2 [[texture(2)]],
#endif // N >= 3
#if N >= 4
texture2d_array<P, access::read> in3 [[texture(3)]],
#endif // N >= 4
#if N >= 5
texture2d_array<P, access::read> in4 [[texture(4)]],
#endif // N >= 5
#if N >= 6
texture2d_array<P, access::read> in5 [[texture(5)]],
#endif // N >= 6
texture2d_array<P, access::write> out [[texture(N)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
int z = gid.z - pm.offset;
if (z < 0) return;
if (z < pm.vdim[0]) {
VECTOR(P, 4) r = in0.read(gid.xy, gid.z);
out.write(r, gid.xy, gid.z);
return;
}
z -= pm.vdim[0];
if (z < pm.vdim[1]) {
VECTOR(P, 4) r = in1.read(gid.xy, z);
out.write(r, gid.xy, gid.z);
return;
}
#if N >= 3
z -= pm.vdim[1];
if (z < pm.vdim[2]) {
VECTOR(P, 4) r = in2.read(gid.xy, z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 3
#if N >= 4
z -= pm.vdim[2];
if (z < pm.vdim[3]) {
VECTOR(P, 4) r = in3.read(gid.xy, z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 4
#if N >= 5
z -= pm.vdim[3];
if (z < pm.vdim[4]) {
VECTOR(P, 4) r = in4.read(gid.xy, z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 5
#if N >= 6
z -= pm.vdim[4];
if (z < pm.vdim[5]) {
VECTOR(P, 4) r = in5.read(gid.xy, z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 6
}
#endif // V == VZ
#undef VV
#endif // #ifdef P
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
struct ConcatParam {
int32_t odim[4];
int32_t axis;
int32_t offset;
int32_t trans[4];
int32_t vdim[6];
};
#define VNORMAL 1
#define VX 2
#define VY 3
#define VZ 4
// >> fast mode
// only support concat_{2,3,4}_{2,3,4,5,6}_y_{float,half}
// only support concat_{3,4}_{2,3,4,5,6}_x_{float,half}
// only support concat_{1,2,3,4}_{2,3,4,5,6}_z_{float,half}
// >> normal mode (loop mode)
// ssd-ar: (R=4, N=3, V=z), (R=3, N=2, V=y), (R=2, N=5, V=x), (R=3, N=5, V=x)
// ssd: (R=2, N=6, V=y), (R=3, N=6, V=y)
// genet: (R=4, N=2, V=normal)
// ssd-ar: (R=3, N=5, V=x)
#define V VX
#define R 3
#define N 5
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
// ssd-ar: (R=2, N=5, V=x)
#define V VX
#define R 2
#define N 5
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
// ssd-ar: (R=3, N=2, V=y)
#define V VY
#define R 3
#define N 2
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
// ssd-ar: (R=4, N=3, V=z)
#define V VZ
#define R 4
#define N 3
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
// ssd: (R=2, N=6, V=y)
#define V VY
#define R 2
#define N 6
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
// ssd: (R=3, N=6, V=y)
#define V VY
#define R 3
#define N 6
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
#define V VNORMAL
#define R 4
#define N 2
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
#ifdef P
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d
#define FUNC(f, r1, r2, p) CONCAT4_(f, r1, r2, p)
#define VECTOR(p, n) CONCAT2(p, n)
#define FUNC_R(f, r) CONCAT2_(f, r)
kernel void FUNC(reshape, RIN, ROUT, P)(texture2d_array<P, access::read> inTexture [[texture(0)]],
texture2d_array<P, access::write> outTexture [[texture(1)]],
constant ReshapeParam &rp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
int oxyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, oabcd[4], ixyzn[4], iabcd[4];
ReshapeParam lrp = rp;
int oC = lrp.odim[lrp.otrans[3]];
int iC = lrp.idim[lrp.itrans[3]];
int count = lrp.odim[0] * lrp.odim[1] * lrp.odim[2] * lrp.odim[3];
VECTOR(P, 4) r;
for (int n = 0; n < 4; n++) {
oxyzn[3] = n;
#if ROUT == 4
xyzn2abcd_4(oC, oxyzn, oabcd);
#else
FUNC_R(xyzn2abcd, ROUT)(oxyzn, oabcd);
#endif
int tabcd[4];
invtrans(lrp.otrans, oabcd, tabcd);
int index = abcd2index(lrp.odim, tabcd);
if (index < count) {
index2abcd(lrp.idim, index, tabcd);
trans(lrp.itrans, tabcd, iabcd);
abcd2xyzn(iC, iabcd, ixyzn);
#if RIN == 4
abcd2xyzn_4(iC, iabcd, ixyzn);
#else
FUNC_R(abcd2xyzn, RIN)(iabcd, ixyzn);
#endif
r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]];
} else {
r[n] = 0;
}
}
outTexture.write(r, gid.xy, gid.z);
}
#endif
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONRITIONS OF ANY KINR, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
...@@ -24,114 +24,127 @@ struct ReshapeParam { ...@@ -24,114 +24,127 @@ struct ReshapeParam {
int32_t otrans[4]; int32_t otrans[4];
}; };
//kernel void reshape(texture2d_array<float, access::read> inTexture [[texture(0)]], #define P float
// texture2d_array<float, access::write> outTexture [[texture(1)]], #define RIN 4
// constant ReshapeParam &rp [[buffer(0)]], #define ROUT 4
// uint3 gid [[thread_position_in_grid]]) { #include "ReshapeKernel.inc.metal"
// if (gid.x >= outTexture.get_width() || #undef ROUT
// gid.y >= outTexture.get_height() || #define ROUT 3
// gid.z >= outTexture.get_array_size()) return; #include "ReshapeKernel.inc.metal"
// #undef ROUT
// int oxyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, oabcd[4], ixyzn[4]; #define ROUT 2
// ReshapeParam lrp = rp; #include "ReshapeKernel.inc.metal"
// int oC = lrp.odim[lrp.otrans[3]]; #undef ROUT
// int iC = lrp.idim[lrp.itrans[3]]; #define ROUT 1
// int count = lrp.odim[0] * lrp.odim[1] * lrp.odim[2] * lrp.odim[3]; #include "ReshapeKernel.inc.metal"
// float4 r; #undef ROUT
// for (int n = 0; n < 4; n++) { #undef RIN
// oxyzn[3] = n;
//
// //4 (gid.x gid.y, gid.z, 0~4)
// xyzn2abcd(oC, oxyzn, oabcd);
// int tabcd[4];
// invtrans(lrp.otrans, oabcd, tabcd);
// int index = abcd2index(lrp.odim, tabcd);
// if (index < count) {
// int c = index % 4;
//
// int temp0 = index % (inTexture.get_array_size() * 4);
// int slice = temp0 / 4;
//
// int temp1 = index % (inTexture.get_array_size() * 4 * lrp.idim[2]);
// int w = temp1 / (inTexture.get_array_size() * 4);
//
// int h = index / (inTexture.get_array_size() * 4 * lrp.idim[2]);
//
//// index2abcd(lrp.idim, index, tabcd);
//// abcd2xyzn(iC, tabcd, ixyzn);
// r[n] = inTexture.read(uint2(w, h), slice)[c];
// } else {
// r[n] = 0;
// }
// }
// outTexture.write(r, gid.xy, gid.z);
//}
#define RIN 3
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#define RIN 2
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#define RIN 1
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
kernel void reshape(texture2d_array<float, access::read> inTexture [[texture(0)]], #undef P
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant ReshapeParam &rp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
int oxyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, oabcd[4], ixyzn[4], iabcd[4]; #define P half
ReshapeParam lrp = rp; #define RIN 4
int oC = lrp.odim[lrp.otrans[3]]; #define ROUT 4
int iC = lrp.idim[lrp.itrans[3]]; #include "ReshapeKernel.inc.metal"
int count = lrp.odim[0] * lrp.odim[1] * lrp.odim[2] * lrp.odim[3]; #undef ROUT
float4 r; #define ROUT 3
for (int n = 0; n < 4; n++) { #include "ReshapeKernel.inc.metal"
oxyzn[3] = n; #undef ROUT
xyzn2abcd(oC, oxyzn, oabcd); #define ROUT 2
int tabcd[4]; #include "ReshapeKernel.inc.metal"
invtrans(lrp.otrans, oabcd, tabcd); #undef ROUT
int index = abcd2index(lrp.odim, tabcd); #define ROUT 1
if (index < count) { #include "ReshapeKernel.inc.metal"
index2abcd(lrp.idim, index, tabcd); #undef ROUT
trans(lrp.itrans, tabcd, iabcd); #undef RIN
abcd2xyzn(iC, iabcd, ixyzn);
r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]];
} else {
r[n] = 0;
}
}
outTexture.write(r, gid.xy, gid.z);
}
#define RIN 3
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
kernel void reshape_half(texture2d_array<half, access::read> inTexture [[texture(0)]], #define RIN 2
texture2d_array<half, access::write> outTexture [[texture(1)]], #define ROUT 4
constant ReshapeParam &rp [[buffer(0)]], #include "ReshapeKernel.inc.metal"
uint3 gid [[thread_position_in_grid]]) { #undef ROUT
if (gid.x >= outTexture.get_width() || #define ROUT 3
gid.y >= outTexture.get_height() || #include "ReshapeKernel.inc.metal"
gid.z >= outTexture.get_array_size()) return; #undef ROUT
#define ROUT 2
int oxyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, oabcd[4], ixyzn[4], iabcd[4]; #include "ReshapeKernel.inc.metal"
ReshapeParam lrp = rp; #undef ROUT
int oC = lrp.odim[lrp.otrans[3]]; #define ROUT 1
int iC = lrp.idim[lrp.itrans[3]]; #include "ReshapeKernel.inc.metal"
int count = lrp.odim[0] * lrp.odim[1] * lrp.odim[2] * lrp.odim[3]; #undef ROUT
half4 r; #undef RIN
for (int n = 0; n < 4; n++) {
oxyzn[3] = n;
xyzn2abcd(oC, oxyzn, oabcd);
int tabcd[4];
invtrans(lrp.otrans, oabcd, tabcd);
int index = abcd2index(lrp.odim, tabcd);
if (index < count) {
index2abcd(lrp.idim, index, tabcd);
trans(lrp.itrans, tabcd, iabcd);
abcd2xyzn(iC, iabcd, ixyzn);
r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]];
} else {
r[n] = 0;
}
}
outTexture.write(r, gid.xy, gid.z);
}
#define RIN 1
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#undef P
#ifdef P
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d
#define CONCAT5_(a, b, c, d, e) a ## _ ## b ## _ ## c ## _ ## d ## _ ## e
#define FUNC(f, r, n, v, p) CONCAT5_(f, r, n, v, p)
#define VECTOR(p, n) CONCAT2(p, n)
#define FUNC_R(f, r) CONCAT2_(f, r)
kernel void FUNC(split, R, N, V, P)(texture2d_array<P, access::read> input [[texture(0)]],
texture2d_array<P, access::write> out1 [[texture(1)]],
texture2d_array<P, access::write> out2 [[texture(2)]],
#if N >= 3
texture2d_array<P, access::write> out3 [[texture(3)]],
#endif
#if N >= 4
texture2d_array<P, access::write> out4 [[texture(4)]],
#endif
constant SplitParam &sp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
VECTOR(P, 4) r = input.read(gid.xy, gid.z);
#if V == y
int y = gid.y - sp.offset;
if (y < sp.vdim[0]) {
out1.write(r, gid.xy, gid.z);
} else {
y -= sp.vdim[0];
if (y < sp.vdim[1]) {
out2.write(r, uint2(gid.x, y), gid.z);
} else {
#if N >= 3
y -= sp.vdim[1];
if (y < sp.vdim[2]) {
out3.write(r, uint2(gid.x, y), gid.z);
} else {
#if N >= 4
y -= sp.vdim[2];
if (y < sp.vdim[3]) {
out4.write(r, uint2(gid.x, y), gid.z);
}
#endif
}
#endif
}
}
#elif V == x
int x = gid.x;
if (x < sp.vdim[0]) {
out1.write(r, gid.xy, gid.z);
} else {
x -= sp.vdim[0];
if (x < sp.vdim[1]) {
out2.write(r, uint2(x, gid.y), gid.z);
} else {
#if N >= 3
x -= sp.vdim[1];
if (x < sp.vdim[2]) {
out3.write(r, uint2(x, gid.y), gid.z);
} else {
#if N >= 4
x -= sp.vdim[2];
if (x < sp.vdim[3]) {
out4.write(r, uint2(x, gid.y), gid.z);
}
#endif
}
#endif
}
}
#else
#endif
}
#endif
...@@ -13,18 +13,60 @@ ...@@ -13,18 +13,60 @@
limitations under the License. */ limitations under the License. */
#include <metal_stdlib> #include <metal_stdlib>
#include "Common.metal"
using namespace metal; using namespace metal;
kernel void split(texture2d_array<float, access::write> output[[texture(0)]], struct SplitParam {
uint3 gid [[thread_position_in_grid]]) { int32_t idim[4];
float4 r; int32_t axis;
int32_t offset;
int32_t trans[4];
int32_t vdim[4];
};
// only support split_{2, 3, 4}_{2, 3, 4}_y_{float, half}
// only support split_{3, 4}_{2, 3, 4}_x_{float, half}
#define V y
// for R in 2..4
#define R 3
// for N in 2..4
#define N 2
#define P float
#include "Split.inc.metal"
#undef P
#define P half
#include "Split.inc.metal"
#undef P
#undef N
// end for N
#undef R
// end for R
#undef V
#define V x
// for R in 3..4
#define R 3
// for N in 2..4
#define N 2
#define P float
#include "Split.inc.metal"
#undef P
#define P half
#include "Split.inc.metal"
#undef P
#undef N
// end for N
output.write(r, gid.xy, gid.z); #undef R
} // end for R
#undef V
kernel void split_half(texture2d_array<half, access::write> output[[texture(0)]],
uint3 gid [[thread_position_in_grid]]) {
float4 r;
output.write(half4(r), gid.xy, gid.z);
}
...@@ -43,15 +43,12 @@ class ReshapeParam<P: PrecisionType>: OpParam { ...@@ -43,15 +43,12 @@ class ReshapeParam<P: PrecisionType>: OpParam {
} }
output.padToFourDim = Dim.init(inDim: dim) output.padToFourDim = Dim.init(inDim: dim)
output.dim = output.padToFourDim output.dim = output.padToFourDim
// inplace = try ReshapeParam.getAttr(key: "inplace", attrs: opDesc.attrs)
} catch let error { } catch let error {
throw error throw error
} }
} }
let input: Texture<P> let input: Texture<P>
let shape: [Int32] let shape: [Int32]
// let inplace: Bool
var output: Texture<P> var output: Texture<P>
} }
......
...@@ -18,17 +18,19 @@ class ShapeParam<P: PrecisionType>: OpParam { ...@@ -18,17 +18,19 @@ class ShapeParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws { required init(opDesc: OpDesc, inScope: Scope) throws {
do { do {
output = try ShapeParam.output(outputs: opDesc.outputs, from: inScope) input = try ShapeParam.input(inputs: opDesc.inputs, from: inScope)
output = try ShapeParam.outputOut(outputs: opDesc.outputs, from: inScope)
} catch let error { } catch let error {
throw error throw error
} }
} }
var output: Texture<P> var output: Texture<P>
let input: Texture<P>
} }
class ShapeOp<P: PrecisionType>: Operator<SplitKernel<P>, SplitParam<P>>, Runable, Creator, InferShaperable{ class ShapeOp<P: PrecisionType>: Operator<ShapeKernel<P>, ShapeParam<P>>, Runable, Creator, InferShaperable{
typealias OpType = SplitOp<P> typealias OpType = ShapeOp<P>
func inferShape() { func inferShape() {
// para.output.dim = para.input.dim // para.output.dim = para.input.dim
......
...@@ -18,13 +18,32 @@ class SplitParam<P: PrecisionType>: OpParam { ...@@ -18,13 +18,32 @@ class SplitParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws { required init(opDesc: OpDesc, inScope: Scope) throws {
do { do {
// output = try SplitParam.output(outputs: opDesc.outputs, from: inScope) input = try SplitParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try SplitParam.outputOut(outputs: opDesc.outputs, from: inScope) output = Texture<P>.init(device: input.metalTexture!.device, inDim: input.dim)
axis = try SplitParam.getAttr(key: "axis", attrs: opDesc.attrs)
sections = try SplitParam.getAttr(key: "sections", attrs: opDesc.attrs)
if axis < 0 {
axis = input.tensorDim.cout() + axis
}
guard let outlist = opDesc.outputs["Out"] else {
fatalError()
}
for out in outlist {
guard let variant = inScope[out], let v = variant as? Texture<P> else {
fatalError()
}
outputList.append(v)
sections.append(Int32(v.tensorDim.dims[axis]))
}
} catch let error { } catch let error {
throw error throw error
} }
} }
var axis: Int
let input: Texture<P>
var output: Texture<P> var output: Texture<P>
var outputList: [Texture<P>] = []
var sections: [Int32] = []
} }
class SplitOp<P: PrecisionType>: Operator<SplitKernel<P>, SplitParam<P>>, Runable, Creator, InferShaperable{ class SplitOp<P: PrecisionType>: Operator<SplitKernel<P>, SplitParam<P>>, Runable, Creator, InferShaperable{
......
...@@ -16,7 +16,7 @@ import Foundation ...@@ -16,7 +16,7 @@ import Foundation
class ScaleKernel: CusomKernel { class ScaleKernel: CusomKernel {
init(device: MTLDevice, shape: Shape) { init(device: MTLDevice, shape: Shape) {
super.init(device: device, inFunctionName: "scale_half", outputDim: shape, usePaddleMobileLib: false) super.init(device: device, inFunctionName: "scale", outputDim: shape, usePaddleMobileLib: false)
} }
} }
......
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
import Foundation import Foundation
let testTo = 3 let testTo = 114
var isTest = false var isTest = false
let computePrecision: ComputePrecision = .Float16 let computePrecision: ComputePrecision = .Float32
public class ResultHolder { public class ResultHolder {
public let dim: [Int] public let dim: [Int]
...@@ -101,7 +101,7 @@ public class Executor<P: PrecisionType> { ...@@ -101,7 +101,7 @@ public class Executor<P: PrecisionType> {
let inputTexture = InputTexture.init(inMTLTexture: resInput, inExpectDim: Dim.init(inDim: dim)) let inputTexture = InputTexture.init(inMTLTexture: resInput, inExpectDim: Dim.init(inDim: dim))
program.scope.setInput(input: inputTexture) program.scope.setInput(input: inputTexture)
//(ops.count - except) //(ops.count - except)
for i in 0..<ops.count { for i in 0..<testTo {
let op = ops[i] let op = ops[i]
do { do {
try op.run(device: device, buffer: buffer) try op.run(device: device, buffer: buffer)
...@@ -112,35 +112,35 @@ public class Executor<P: PrecisionType> { ...@@ -112,35 +112,35 @@ public class Executor<P: PrecisionType> {
var outputTextures: [String : [Variant]]? var outputTextures: [String : [Variant]]?
if except > 0 { if except > 0 {
outputTextures = ops[ops.count - except].inputVariant() outputTextures = ops[testTo-1].inputVariant()
} }
buffer.addCompletedHandler { [weak self] (commandbuffer) in buffer.addCompletedHandler { [weak self] (commandbuffer) in
// let inputArr = resInput.toTensor(dim: (n: dim[0], c: dim[3], h: dim[1], w: dim[2])) let inputArr = resInput.toTensor(dim: (n: dim[0], c: dim[3], h: dim[1], w: dim[2]))
//// print(inputArr.strideArray()) print(inputArr.strideArray())
// print(dim) // print(dim)
// writeToLibrary(fileName: "test_image_ssd_ar", array: inputArr) // writeToLibrary(fileName: "test_image_ssd_ar", array: inputArr)
//
// print("write to library done") // print("write to library done")
// return // return
// print(inputArr) // print(inputArr)
//
// let stridableInput: [(index: Int, value: Float)] = input.stridableFloatArray() // let stridableInput: [(index: Int, value: Float)] = input.stridableFloatArray()
// print(stridableInput) // print(stridableInput)
//
// let _: Flo? = input.logDesc(header: "input: ", stridable: true) // let _: Flo? = input.logDesc(header: "input: ", stridable: true)
// for i in 0..<self.ops.count { for i in 0..<testTo {
// let op = self.ops[i] let op = self!.ops[i]
// print(" 第 \(i) 个 op: ") print(" 第 \(i) 个 op: ")
// op.delogOutput() op.delogOutput()
// } }
// return; // return;
// self.ops[testTo - 2].delogOutput() // self!.ops[testTo - 2].delogOutput()
// self.ops[testTo - 1].delogOutput() // self!.ops[testTo - 1].delogOutput()
// self.ops[60].delogOutput() // self!.ops[60].delogOutput()
// return // return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册