“24509f4af942bb250564756ad636691c7921e1df”上不存在“paddle/legacy/gserver/layers/CudnnPoolLayer.cpp”
提交 6b394a05 编写于 作者: L liuruilong

format files and adjust architecture

上级 67203536
...@@ -96,3 +96,6 @@ metal/paddle-mobile/paddle-mobile/CPU/libpaddle-mobile.a ...@@ -96,3 +96,6 @@ metal/paddle-mobile/paddle-mobile/CPU/libpaddle-mobile.a
metal/paddle-mobile-demo/paddle-mobile-demo/images metal/paddle-mobile-demo/paddle-mobile-demo/images
metal/paddle-mobile-demo/paddle-mobile-demo/models metal/paddle-mobile-demo/paddle-mobile-demo/models
metal/paddle-mobile-demo/paddle-mobile-demo/Resources
metal/paddle-mobile-demo/paddle-mobile-demo/Resources/images
metal/paddle-mobile-demo/paddle-mobile-demo/Resources/models
...@@ -10,13 +10,26 @@ ...@@ -10,13 +10,26 @@
30D0ED21F392CFA3885B1002 /* Pods_paddle_mobile_demo.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 18896810981724F8A0FED62A /* Pods_paddle_mobile_demo.framework */; }; 30D0ED21F392CFA3885B1002 /* Pods_paddle_mobile_demo.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 18896810981724F8A0FED62A /* Pods_paddle_mobile_demo.framework */; };
C2CBB49021B778EA0020DC6C /* libc++.tbd in Frameworks */ = {isa = PBXBuildFile; fileRef = FC4FD97B2140EE250073E130 /* libc++.tbd */; }; C2CBB49021B778EA0020DC6C /* libc++.tbd in Frameworks */ = {isa = PBXBuildFile; fileRef = FC4FD97B2140EE250073E130 /* libc++.tbd */; };
C2E67E5E21524E460013F575 /* LoadPointerViewController.m in Sources */ = {isa = PBXBuildFile; fileRef = C2E67E5D21524E460013F575 /* LoadPointerViewController.m */; }; C2E67E5E21524E460013F575 /* LoadPointerViewController.m in Sources */ = {isa = PBXBuildFile; fileRef = C2E67E5D21524E460013F575 /* LoadPointerViewController.m */; };
FC013928210204A3008100E3 /* PreProcessKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC013927210204A3008100E3 /* PreProcessKernel.metal */; };
FC039B8220E11C550081E9F8 /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039B8120E11C550081E9F8 /* AppDelegate.swift */; }; FC039B8220E11C550081E9F8 /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039B8120E11C550081E9F8 /* AppDelegate.swift */; };
FC039B8420E11C550081E9F8 /* ViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039B8320E11C550081E9F8 /* ViewController.swift */; }; FC039B8420E11C550081E9F8 /* ViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039B8320E11C550081E9F8 /* ViewController.swift */; };
FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8520E11C550081E9F8 /* Main.storyboard */; }; FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8520E11C550081E9F8 /* Main.storyboard */; };
FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8820E11C560081E9F8 /* Assets.xcassets */; }; FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8820E11C560081E9F8 /* Assets.xcassets */; };
FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8A20E11C560081E9F8 /* LaunchScreen.storyboard */; }; FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8A20E11C560081E9F8 /* LaunchScreen.storyboard */; };
FC203FB221CBFDBA00B37166 /* test.jpg in Resources */ = {isa = PBXBuildFile; fileRef = FC203FA921CBFDBA00B37166 /* test.jpg */; }; FC203FB221CBFDBA00B37166 /* test.jpg in Resources */ = {isa = PBXBuildFile; fileRef = FC203FA921CBFDBA00B37166 /* test.jpg */; };
FC2BFCBC21DF0A8600C262B2 /* 00001.jpg in Resources */ = {isa = PBXBuildFile; fileRef = FC2BFCBB21DF0A8600C262B2 /* 00001.jpg */; };
FC2BFCBE21DF15D900C262B2 /* 123.jpg in Resources */ = {isa = PBXBuildFile; fileRef = FC2BFCBD21DF15D900C262B2 /* 123.jpg */; };
FC2BFCC021DF279900C262B2 /* classify-img-output.png in Resources */ = {isa = PBXBuildFile; fileRef = FC2BFCBF21DF279900C262B2 /* classify-img-output.png */; };
FC2BFD3021DF3FEA00C262B2 /* MobilenetSSD_AR.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC2BFD2A21DF3FE900C262B2 /* MobilenetSSD_AR.swift */; };
FC2BFD3121DF3FEA00C262B2 /* Genet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC2BFD2B21DF3FE900C262B2 /* Genet.swift */; };
FC2BFD3221DF3FEA00C262B2 /* MobileNetSSD.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC2BFD2C21DF3FE900C262B2 /* MobileNetSSD.swift */; };
FC2BFD3321DF3FEA00C262B2 /* YoloNet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC2BFD2D21DF3FE900C262B2 /* YoloNet.swift */; };
FC2BFD3421DF3FEA00C262B2 /* MobileNetCombined.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC2BFD2E21DF3FEA00C262B2 /* MobileNetCombined.swift */; };
FC2BFD3521DF3FEA00C262B2 /* MobileNet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC2BFD2F21DF3FEA00C262B2 /* MobileNet.swift */; };
FC2BFD3821DF46DE00C262B2 /* OCDemoViewController.m in Sources */ = {isa = PBXBuildFile; fileRef = FC2BFD3721DF46DE00C262B2 /* OCDemoViewController.m */; };
FC2BFD3C21DF480400C262B2 /* CPUCompute.mm in Sources */ = {isa = PBXBuildFile; fileRef = FC2BFD3B21DF480400C262B2 /* CPUCompute.mm */; };
FC2BFD3E21DF5CE800C262B2 /* PreProcessKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC2BFD3D21DF5CE800C262B2 /* PreProcessKernel.metal */; };
FC2BFD4321DF5E1E00C262B2 /* PaddleMobileGPU.m in Sources */ = {isa = PBXBuildFile; fileRef = FC2BFD4021DF5E1E00C262B2 /* PaddleMobileGPU.m */; };
FC2BFD4421DF5E1E00C262B2 /* SuperResolutionNet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC2BFD4221DF5E1E00C262B2 /* SuperResolutionNet.swift */; };
FC5E03B221DCE8D90016C137 /* mingren_input_data in Resources */ = {isa = PBXBuildFile; fileRef = FC5E03B121DCE8D90016C137 /* mingren_input_data */; }; FC5E03B221DCE8D90016C137 /* mingren_input_data in Resources */ = {isa = PBXBuildFile; fileRef = FC5E03B121DCE8D90016C137 /* mingren_input_data */; };
FC704C1921D2375300F98BAB /* super_params in Resources */ = {isa = PBXBuildFile; fileRef = FC704C1721D2375300F98BAB /* super_params */; }; FC704C1921D2375300F98BAB /* super_params in Resources */ = {isa = PBXBuildFile; fileRef = FC704C1721D2375300F98BAB /* super_params */; };
FC704C1A21D2375300F98BAB /* super_model in Resources */ = {isa = PBXBuildFile; fileRef = FC704C1821D2375300F98BAB /* super_model */; }; FC704C1A21D2375300F98BAB /* super_model in Resources */ = {isa = PBXBuildFile; fileRef = FC704C1821D2375300F98BAB /* super_model */; };
...@@ -31,7 +44,6 @@ ...@@ -31,7 +44,6 @@
FC9797C321D608E000F2FD90 /* mobilenet_params in Resources */ = {isa = PBXBuildFile; fileRef = FC9797C121D608DF00F2FD90 /* mobilenet_params */; }; FC9797C321D608E000F2FD90 /* mobilenet_params in Resources */ = {isa = PBXBuildFile; fileRef = FC9797C121D608DF00F2FD90 /* mobilenet_params */; };
FC9797C721D609FB00F2FD90 /* synset.txt in Resources */ = {isa = PBXBuildFile; fileRef = FC9797C621D609FB00F2FD90 /* synset.txt */; }; FC9797C721D609FB00F2FD90 /* synset.txt in Resources */ = {isa = PBXBuildFile; fileRef = FC9797C621D609FB00F2FD90 /* synset.txt */; };
FC9797CF21D6506F00F2FD90 /* mingren.jpg in Resources */ = {isa = PBXBuildFile; fileRef = FC9797CE21D6506F00F2FD90 /* mingren.jpg */; }; FC9797CF21D6506F00F2FD90 /* mingren.jpg in Resources */ = {isa = PBXBuildFile; fileRef = FC9797CE21D6506F00F2FD90 /* mingren.jpg */; };
FC9797D121D6616600F2FD90 /* BufferToTexture.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC9797D021D6616600F2FD90 /* BufferToTexture.metal */; };
FCBCCC552122EF5500D94F7E /* MetalHelper.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCBCCC542122EF5400D94F7E /* MetalHelper.swift */; }; FCBCCC552122EF5500D94F7E /* MetalHelper.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCBCCC542122EF5400D94F7E /* MetalHelper.swift */; };
FCCED60521D7646E00BE8D5F /* test_image_super in Resources */ = {isa = PBXBuildFile; fileRef = FCCED60421D7646E00BE8D5F /* test_image_super */; }; FCCED60521D7646E00BE8D5F /* test_image_super in Resources */ = {isa = PBXBuildFile; fileRef = FCCED60421D7646E00BE8D5F /* test_image_super */; };
FCEBEC2C20E1391F00C0B14D /* paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; }; FCEBEC2C20E1391F00C0B14D /* paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; };
...@@ -59,7 +71,6 @@ ...@@ -59,7 +71,6 @@
878829884E1A14D7044721D5 /* Pods-paddle-mobile-demo.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-demo.debug.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-demo/Pods-paddle-mobile-demo.debug.xcconfig"; sourceTree = "<group>"; }; 878829884E1A14D7044721D5 /* Pods-paddle-mobile-demo.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-demo.debug.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-demo/Pods-paddle-mobile-demo.debug.xcconfig"; sourceTree = "<group>"; };
C2E67E5C21524E460013F575 /* LoadPointerViewController.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = LoadPointerViewController.h; sourceTree = "<group>"; }; C2E67E5C21524E460013F575 /* LoadPointerViewController.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = LoadPointerViewController.h; sourceTree = "<group>"; };
C2E67E5D21524E460013F575 /* LoadPointerViewController.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = LoadPointerViewController.m; sourceTree = "<group>"; }; C2E67E5D21524E460013F575 /* LoadPointerViewController.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = LoadPointerViewController.m; sourceTree = "<group>"; };
FC013927210204A3008100E3 /* PreProcessKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = PreProcessKernel.metal; sourceTree = "<group>"; };
FC039B7E20E11C550081E9F8 /* paddle-mobile-demo.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = "paddle-mobile-demo.app"; sourceTree = BUILT_PRODUCTS_DIR; }; FC039B7E20E11C550081E9F8 /* paddle-mobile-demo.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = "paddle-mobile-demo.app"; sourceTree = BUILT_PRODUCTS_DIR; };
FC039B8120E11C550081E9F8 /* AppDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppDelegate.swift; sourceTree = "<group>"; }; FC039B8120E11C550081E9F8 /* AppDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppDelegate.swift; sourceTree = "<group>"; };
FC039B8320E11C550081E9F8 /* ViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ViewController.swift; sourceTree = "<group>"; }; FC039B8320E11C550081E9F8 /* ViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ViewController.swift; sourceTree = "<group>"; };
...@@ -69,6 +80,23 @@ ...@@ -69,6 +80,23 @@
FC039B8D20E11C560081E9F8 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; }; FC039B8D20E11C560081E9F8 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
FC203FA921CBFDBA00B37166 /* test.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = test.jpg; sourceTree = "<group>"; }; FC203FA921CBFDBA00B37166 /* test.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = test.jpg; sourceTree = "<group>"; };
FC27991121343A39000B6BAD /* paddle-mobile-demo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "paddle-mobile-demo-Bridging-Header.h"; sourceTree = "<group>"; }; FC27991121343A39000B6BAD /* paddle-mobile-demo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "paddle-mobile-demo-Bridging-Header.h"; sourceTree = "<group>"; };
FC2BFCBB21DF0A8600C262B2 /* 00001.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = 00001.jpg; sourceTree = "<group>"; };
FC2BFCBD21DF15D900C262B2 /* 123.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = 123.jpg; sourceTree = "<group>"; };
FC2BFCBF21DF279900C262B2 /* classify-img-output.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "classify-img-output.png"; sourceTree = "<group>"; };
FC2BFD2A21DF3FE900C262B2 /* MobilenetSSD_AR.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MobilenetSSD_AR.swift; sourceTree = "<group>"; };
FC2BFD2B21DF3FE900C262B2 /* Genet.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Genet.swift; sourceTree = "<group>"; };
FC2BFD2C21DF3FE900C262B2 /* MobileNetSSD.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MobileNetSSD.swift; sourceTree = "<group>"; };
FC2BFD2D21DF3FE900C262B2 /* YoloNet.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = YoloNet.swift; sourceTree = "<group>"; };
FC2BFD2E21DF3FEA00C262B2 /* MobileNetCombined.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MobileNetCombined.swift; sourceTree = "<group>"; };
FC2BFD2F21DF3FEA00C262B2 /* MobileNet.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MobileNet.swift; sourceTree = "<group>"; };
FC2BFD3621DF46DE00C262B2 /* OCDemoViewController.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = OCDemoViewController.h; sourceTree = "<group>"; };
FC2BFD3721DF46DE00C262B2 /* OCDemoViewController.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = OCDemoViewController.m; sourceTree = "<group>"; };
FC2BFD3A21DF480300C262B2 /* CPUCompute.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CPUCompute.h; sourceTree = "<group>"; };
FC2BFD3B21DF480400C262B2 /* CPUCompute.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = CPUCompute.mm; sourceTree = "<group>"; };
FC2BFD3D21DF5CE800C262B2 /* PreProcessKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = PreProcessKernel.metal; sourceTree = "<group>"; };
FC2BFD4021DF5E1E00C262B2 /* PaddleMobileGPU.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = PaddleMobileGPU.m; sourceTree = "<group>"; };
FC2BFD4121DF5E1E00C262B2 /* PaddleMobileGPU.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = PaddleMobileGPU.h; sourceTree = "<group>"; };
FC2BFD4221DF5E1E00C262B2 /* SuperResolutionNet.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SuperResolutionNet.swift; sourceTree = "<group>"; };
FC4FD97B2140EE250073E130 /* libc++.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = "libc++.tbd"; path = "usr/lib/libc++.tbd"; sourceTree = SDKROOT; }; FC4FD97B2140EE250073E130 /* libc++.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = "libc++.tbd"; path = "usr/lib/libc++.tbd"; sourceTree = SDKROOT; };
FC5E03B121DCE8D90016C137 /* mingren_input_data */ = {isa = PBXFileReference; lastKnownFileType = file; path = mingren_input_data; sourceTree = "<group>"; }; FC5E03B121DCE8D90016C137 /* mingren_input_data */ = {isa = PBXFileReference; lastKnownFileType = file; path = mingren_input_data; sourceTree = "<group>"; };
FC704C1721D2375300F98BAB /* super_params */ = {isa = PBXFileReference; lastKnownFileType = file; path = super_params; sourceTree = "<group>"; }; FC704C1721D2375300F98BAB /* super_params */ = {isa = PBXFileReference; lastKnownFileType = file; path = super_params; sourceTree = "<group>"; };
...@@ -84,7 +112,6 @@ ...@@ -84,7 +112,6 @@
FC9797C121D608DF00F2FD90 /* mobilenet_params */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_params; sourceTree = "<group>"; }; FC9797C121D608DF00F2FD90 /* mobilenet_params */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_params; sourceTree = "<group>"; };
FC9797C621D609FB00F2FD90 /* synset.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = synset.txt; sourceTree = "<group>"; }; FC9797C621D609FB00F2FD90 /* synset.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = synset.txt; sourceTree = "<group>"; };
FC9797CE21D6506F00F2FD90 /* mingren.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = mingren.jpg; sourceTree = "<group>"; }; FC9797CE21D6506F00F2FD90 /* mingren.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = mingren.jpg; sourceTree = "<group>"; };
FC9797D021D6616600F2FD90 /* BufferToTexture.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = BufferToTexture.metal; sourceTree = "<group>"; };
FCBCCC542122EF5400D94F7E /* MetalHelper.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MetalHelper.swift; sourceTree = "<group>"; }; FCBCCC542122EF5400D94F7E /* MetalHelper.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MetalHelper.swift; sourceTree = "<group>"; };
FCCED60421D7646E00BE8D5F /* test_image_super */ = {isa = PBXFileReference; lastKnownFileType = file; path = test_image_super; sourceTree = "<group>"; }; FCCED60421D7646E00BE8D5F /* test_image_super */ = {isa = PBXFileReference; lastKnownFileType = file; path = test_image_super; sourceTree = "<group>"; };
FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; }; FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; };
...@@ -145,8 +172,11 @@ ...@@ -145,8 +172,11 @@
FC039B8020E11C550081E9F8 /* paddle-mobile-demo */ = { FC039B8020E11C550081E9F8 /* paddle-mobile-demo */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC203FA821CBFDBA00B37166 /* images */, FC2BFD4F21DF892500C262B2 /* Resources */,
FC203FAA21CBFDBA00B37166 /* models */, FCBCCC542122EF5400D94F7E /* MetalHelper.swift */,
FC2BFD3F21DF5DDF00C262B2 /* OCInterface */,
FC27991121343A39000B6BAD /* paddle-mobile-demo-Bridging-Header.h */,
FC2BFD3921DF46F000C262B2 /* OCDemo */,
FC803BCA214D27920094B8E5 /* VideoCapture */, FC803BCA214D27920094B8E5 /* VideoCapture */,
FC8CFED2213519540094D569 /* Net */, FC8CFED2213519540094D569 /* Net */,
FC039B8120E11C550081E9F8 /* AppDelegate.swift */, FC039B8120E11C550081E9F8 /* AppDelegate.swift */,
...@@ -155,10 +185,7 @@ ...@@ -155,10 +185,7 @@
FC039B8820E11C560081E9F8 /* Assets.xcassets */, FC039B8820E11C560081E9F8 /* Assets.xcassets */,
FC039B8A20E11C560081E9F8 /* LaunchScreen.storyboard */, FC039B8A20E11C560081E9F8 /* LaunchScreen.storyboard */,
FC039B8D20E11C560081E9F8 /* Info.plist */, FC039B8D20E11C560081E9F8 /* Info.plist */,
FC27991121343A39000B6BAD /* paddle-mobile-demo-Bridging-Header.h */,
FCF437E7214B6DDB00943429 /* MultiPredictViewController.swift */, FCF437E7214B6DDB00943429 /* MultiPredictViewController.swift */,
C2E67E5C21524E460013F575 /* LoadPointerViewController.h */,
C2E67E5D21524E460013F575 /* LoadPointerViewController.m */,
); );
path = "paddle-mobile-demo"; path = "paddle-mobile-demo";
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -166,14 +193,16 @@ ...@@ -166,14 +193,16 @@
FC203FA821CBFDBA00B37166 /* images */ = { FC203FA821CBFDBA00B37166 /* images */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC2BFCBF21DF279900C262B2 /* classify-img-output.png */,
FC2BFCBD21DF15D900C262B2 /* 123.jpg */,
FC2BFCBB21DF0A8600C262B2 /* 00001.jpg */,
FC5E03B121DCE8D90016C137 /* mingren_input_data */, FC5E03B121DCE8D90016C137 /* mingren_input_data */,
FCCED60421D7646E00BE8D5F /* test_image_super */, FCCED60421D7646E00BE8D5F /* test_image_super */,
FC9797CE21D6506F00F2FD90 /* mingren.jpg */, FC9797CE21D6506F00F2FD90 /* mingren.jpg */,
FC9797BD21D6045B00F2FD90 /* banana.jpeg */, FC9797BD21D6045B00F2FD90 /* banana.jpeg */,
FC203FA921CBFDBA00B37166 /* test.jpg */, FC203FA921CBFDBA00B37166 /* test.jpg */,
); );
name = images; path = images;
path = ../../images;
sourceTree = "<group>"; sourceTree = "<group>";
}; };
FC203FAA21CBFDBA00B37166 /* models */ = { FC203FAA21CBFDBA00B37166 /* models */ = {
...@@ -183,8 +212,37 @@ ...@@ -183,8 +212,37 @@
FC704C1B21D237FC00F98BAB /* vision_model */, FC704C1B21D237FC00F98BAB /* vision_model */,
FC704C1621D2375300F98BAB /* superresoltion */, FC704C1621D2375300F98BAB /* superresoltion */,
); );
name = models; path = models;
path = ../../models; sourceTree = "<group>";
};
FC2BFD3921DF46F000C262B2 /* OCDemo */ = {
isa = PBXGroup;
children = (
C2E67E5C21524E460013F575 /* LoadPointerViewController.h */,
C2E67E5D21524E460013F575 /* LoadPointerViewController.m */,
FC2BFD3621DF46DE00C262B2 /* OCDemoViewController.h */,
FC2BFD3721DF46DE00C262B2 /* OCDemoViewController.m */,
);
path = OCDemo;
sourceTree = "<group>";
};
FC2BFD3F21DF5DDF00C262B2 /* OCInterface */ = {
isa = PBXGroup;
children = (
FC2BFD4121DF5E1E00C262B2 /* PaddleMobileGPU.h */,
FC2BFD4021DF5E1E00C262B2 /* PaddleMobileGPU.m */,
FC2BFD4221DF5E1E00C262B2 /* SuperResolutionNet.swift */,
);
path = OCInterface;
sourceTree = "<group>";
};
FC2BFD4F21DF892500C262B2 /* Resources */ = {
isa = PBXGroup;
children = (
FC203FA821CBFDBA00B37166 /* images */,
FC203FAA21CBFDBA00B37166 /* models */,
);
path = Resources;
sourceTree = "<group>"; sourceTree = "<group>";
}; };
FC704C1621D2375300F98BAB /* superresoltion */ = { FC704C1621D2375300F98BAB /* superresoltion */ = {
...@@ -235,9 +293,15 @@ ...@@ -235,9 +293,15 @@
FC8CFED2213519540094D569 /* Net */ = { FC8CFED2213519540094D569 /* Net */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC013927210204A3008100E3 /* PreProcessKernel.metal */, FC2BFD3A21DF480300C262B2 /* CPUCompute.h */,
FCBCCC542122EF5400D94F7E /* MetalHelper.swift */, FC2BFD3B21DF480400C262B2 /* CPUCompute.mm */,
FC9797D021D6616600F2FD90 /* BufferToTexture.metal */, FC2BFD3D21DF5CE800C262B2 /* PreProcessKernel.metal */,
FC2BFD2B21DF3FE900C262B2 /* Genet.swift */,
FC2BFD2F21DF3FEA00C262B2 /* MobileNet.swift */,
FC2BFD2E21DF3FEA00C262B2 /* MobileNetCombined.swift */,
FC2BFD2A21DF3FE900C262B2 /* MobilenetSSD_AR.swift */,
FC2BFD2C21DF3FE900C262B2 /* MobileNetSSD.swift */,
FC2BFD2D21DF3FE900C262B2 /* YoloNet.swift */,
); );
path = Net; path = Net;
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -319,16 +383,19 @@ ...@@ -319,16 +383,19 @@
FC9797CF21D6506F00F2FD90 /* mingren.jpg in Resources */, FC9797CF21D6506F00F2FD90 /* mingren.jpg in Resources */,
FC704C2221D237FC00F98BAB /* combined_mobilenet_params in Resources */, FC704C2221D237FC00F98BAB /* combined_mobilenet_params in Resources */,
FC704C1921D2375300F98BAB /* super_params in Resources */, FC704C1921D2375300F98BAB /* super_params in Resources */,
FC2BFCBE21DF15D900C262B2 /* 123.jpg in Resources */,
FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */, FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */,
FC9797C721D609FB00F2FD90 /* synset.txt in Resources */, FC9797C721D609FB00F2FD90 /* synset.txt in Resources */,
FC5E03B221DCE8D90016C137 /* mingren_input_data in Resources */, FC5E03B221DCE8D90016C137 /* mingren_input_data in Resources */,
FC704C1A21D2375300F98BAB /* super_model in Resources */, FC704C1A21D2375300F98BAB /* super_model in Resources */,
FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */, FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */,
FC9797C221D608E000F2FD90 /* mobilenet_model in Resources */, FC9797C221D608E000F2FD90 /* mobilenet_model in Resources */,
FC2BFCC021DF279900C262B2 /* classify-img-output.png in Resources */,
FC203FB221CBFDBA00B37166 /* test.jpg in Resources */, FC203FB221CBFDBA00B37166 /* test.jpg in Resources */,
FC704C2321D237FC00F98BAB /* combined_mobilenet_model in Resources */, FC704C2321D237FC00F98BAB /* combined_mobilenet_model in Resources */,
FC9797C321D608E000F2FD90 /* mobilenet_params in Resources */, FC9797C321D608E000F2FD90 /* mobilenet_params in Resources */,
FC704C2421D237FC00F98BAB /* yolo_params in Resources */, FC704C2421D237FC00F98BAB /* yolo_params in Resources */,
FC2BFCBC21DF0A8600C262B2 /* 00001.jpg in Resources */,
FC9797BE21D6045B00F2FD90 /* banana.jpeg in Resources */, FC9797BE21D6045B00F2FD90 /* banana.jpeg in Resources */,
FC704C2521D237FC00F98BAB /* yolo_model in Resources */, FC704C2521D237FC00F98BAB /* yolo_model in Resources */,
); );
...@@ -380,15 +447,24 @@ ...@@ -380,15 +447,24 @@
isa = PBXSourcesBuildPhase; isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647; buildActionMask = 2147483647;
files = ( files = (
FC2BFD3221DF3FEA00C262B2 /* MobileNetSSD.swift in Sources */,
FC2BFD3C21DF480400C262B2 /* CPUCompute.mm in Sources */,
FC2BFD4321DF5E1E00C262B2 /* PaddleMobileGPU.m in Sources */,
FC039B8420E11C550081E9F8 /* ViewController.swift in Sources */, FC039B8420E11C550081E9F8 /* ViewController.swift in Sources */,
FC803BCE214D27930094B8E5 /* VideoCapture.swift in Sources */, FC803BCE214D27930094B8E5 /* VideoCapture.swift in Sources */,
FC013928210204A3008100E3 /* PreProcessKernel.metal in Sources */,
FCF437E8214B6DDB00943429 /* MultiPredictViewController.swift in Sources */, FCF437E8214B6DDB00943429 /* MultiPredictViewController.swift in Sources */,
FC2BFD3021DF3FEA00C262B2 /* MobilenetSSD_AR.swift in Sources */,
FC2BFD3321DF3FEA00C262B2 /* YoloNet.swift in Sources */,
FC2BFD3421DF3FEA00C262B2 /* MobileNetCombined.swift in Sources */,
FCBCCC552122EF5500D94F7E /* MetalHelper.swift in Sources */, FCBCCC552122EF5500D94F7E /* MetalHelper.swift in Sources */,
FC803BCD214D27930094B8E5 /* FPSCounter.swift in Sources */, FC803BCD214D27930094B8E5 /* FPSCounter.swift in Sources */,
FC2BFD3521DF3FEA00C262B2 /* MobileNet.swift in Sources */,
C2E67E5E21524E460013F575 /* LoadPointerViewController.m in Sources */, C2E67E5E21524E460013F575 /* LoadPointerViewController.m in Sources */,
FC2BFD3121DF3FEA00C262B2 /* Genet.swift in Sources */,
FC039B8220E11C550081E9F8 /* AppDelegate.swift in Sources */, FC039B8220E11C550081E9F8 /* AppDelegate.swift in Sources */,
FC9797D121D6616600F2FD90 /* BufferToTexture.metal in Sources */, FC2BFD4421DF5E1E00C262B2 /* SuperResolutionNet.swift in Sources */,
FC2BFD3E21DF5CE800C262B2 /* PreProcessKernel.metal in Sources */,
FC2BFD3821DF46DE00C262B2 /* OCDemoViewController.m in Sources */,
); );
runOnlyForDeploymentPostprocessing = 0; runOnlyForDeploymentPostprocessing = 0;
}; };
......
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<document type="com.apple.InterfaceBuilder3.CocoaTouch.Storyboard.XIB" version="3.0" toolsVersion="14113" targetRuntime="iOS.CocoaTouch" propertyAccessControl="none" useAutolayout="YES" useTraitCollections="YES" useSafeAreas="YES" colorMatched="YES" initialViewController="BYZ-38-t0r"> <document type="com.apple.InterfaceBuilder3.CocoaTouch.Storyboard.XIB" version="3.0" toolsVersion="14460.31" targetRuntime="iOS.CocoaTouch" propertyAccessControl="none" useAutolayout="YES" useTraitCollections="YES" useSafeAreas="YES" colorMatched="YES" initialViewController="BYZ-38-t0r">
<device id="retina4_7" orientation="portrait"> <device id="retina4_7" orientation="portrait">
<adaptation id="fullscreen"/> <adaptation id="fullscreen"/>
</device> </device>
<dependencies> <dependencies>
<deployment identifier="iOS"/> <deployment identifier="iOS"/>
<plugIn identifier="com.apple.InterfaceBuilder.IBCocoaTouchPlugin" version="14088"/> <plugIn identifier="com.apple.InterfaceBuilder.IBCocoaTouchPlugin" version="14460.20"/>
<capability name="Aspect ratio constraints" minToolsVersion="5.1"/>
<capability name="Safe area layout guides" minToolsVersion="9.0"/> <capability name="Safe area layout guides" minToolsVersion="9.0"/>
<capability name="documents saved in the Xcode 8 format" minToolsVersion="8.0"/> <capability name="documents saved in the Xcode 8 format" minToolsVersion="8.0"/>
</dependencies> </dependencies>
...@@ -20,7 +19,7 @@ ...@@ -20,7 +19,7 @@
<autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/> <autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/>
<subviews> <subviews>
<button opaque="NO" contentMode="scaleToFill" contentHorizontalAlignment="center" contentVerticalAlignment="center" buttonType="roundedRect" lineBreakMode="middleTruncation" translatesAutoresizingMaskIntoConstraints="NO" id="TQt-X9-PdF"> <button opaque="NO" contentMode="scaleToFill" contentHorizontalAlignment="center" contentVerticalAlignment="center" buttonType="roundedRect" lineBreakMode="middleTruncation" translatesAutoresizingMaskIntoConstraints="NO" id="TQt-X9-PdF">
<rect key="frame" x="164" y="318" width="46" height="30"/> <rect key="frame" x="164.5" y="318.5" width="46" height="30"/>
<state key="normal" title="Button"/> <state key="normal" title="Button"/>
<connections> <connections>
<action selector="predictAct:" destination="Vwd-lt-764" eventType="touchUpInside" id="d4z-Cv-6jY"/> <action selector="predictAct:" destination="Vwd-lt-764" eventType="touchUpInside" id="d4z-Cv-6jY"/>
...@@ -60,7 +59,7 @@ ...@@ -60,7 +59,7 @@
<nil key="highlightedColor"/> <nil key="highlightedColor"/>
</label> </label>
<pickerView contentMode="scaleToFill" translatesAutoresizingMaskIntoConstraints="NO" id="DlO-dk-RMr"> <pickerView contentMode="scaleToFill" translatesAutoresizingMaskIntoConstraints="NO" id="DlO-dk-RMr">
<rect key="frame" x="55" y="510.5" width="320" height="80"/> <rect key="frame" x="55" y="510" width="320" height="80"/>
<constraints> <constraints>
<constraint firstAttribute="height" constant="80" id="Sbi-05-Mwd"/> <constraint firstAttribute="height" constant="80" id="Sbi-05-Mwd"/>
</constraints> </constraints>
...@@ -83,6 +82,9 @@ ...@@ -83,6 +82,9 @@
<button opaque="NO" contentMode="scaleToFill" contentHorizontalAlignment="center" contentVerticalAlignment="center" buttonType="roundedRect" showsTouchWhenHighlighted="YES" lineBreakMode="middleTruncation" translatesAutoresizingMaskIntoConstraints="NO" id="wUL-9N-u1V"> <button opaque="NO" contentMode="scaleToFill" contentHorizontalAlignment="center" contentVerticalAlignment="center" buttonType="roundedRect" showsTouchWhenHighlighted="YES" lineBreakMode="middleTruncation" translatesAutoresizingMaskIntoConstraints="NO" id="wUL-9N-u1V">
<rect key="frame" x="16" y="597" width="63.5" height="30"/> <rect key="frame" x="16" y="597" width="63.5" height="30"/>
<color key="backgroundColor" white="0.0" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/> <color key="backgroundColor" white="0.0" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/>
<constraints>
<constraint firstAttribute="width" secondItem="wUL-9N-u1V" secondAttribute="height" multiplier="21:10" id="cp7-bd-CvU"/>
</constraints>
<state key="normal" title="Image"> <state key="normal" title="Image">
<color key="titleColor" white="1" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/> <color key="titleColor" white="1" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/>
</state> </state>
......
//
// LoadPointerViewController.h
// paddle-mobile-demo
//
// Created by Xiao,Haichun on 2018/9/19.
// Copyright © 2018年 orange. All rights reserved.
//
#import <UIKit/UIKit.h>
@interface LoadPointerViewController : UIViewController
@end
...@@ -27,7 +27,4 @@ public class MetalHelper { ...@@ -27,7 +27,4 @@ public class MetalHelper {
queue = device.makeCommandQueue()! queue = device.makeCommandQueue()!
textureLoader = MTKTextureLoader.init(device: device) textureLoader = MTKTextureLoader.init(device: device)
} }
} }
...@@ -13,41 +13,35 @@ ...@@ -13,41 +13,35 @@
limitations under the License. */ limitations under the License. */
import Foundation import Foundation
import paddle_mobile
public class Genet: Net { public class Genet: Net {
@objc public override init(device: MTLDevice) { @objc public override init(device: MTLDevice) {
super.init(device: device) super.init(device: device)
means = [128.0, 128.0, 128.0]
scale = 0.017
except = 0
modelPath = Bundle.main.path(forResource: "genet_model", ofType: nil) ?! "model null" modelPath = Bundle.main.path(forResource: "genet_model", ofType: nil) ?! "model null"
paramPath = Bundle.main.path(forResource: "genet_params", ofType: nil) ?! "para null" paramPath = Bundle.main.path(forResource: "genet_params", ofType: nil) ?! "para null"
modelDir = ""
preprocessKernel = GenetPreProccess.init(device: device) preprocessKernel = GenetPreProccess.init(device: device)
inputDim_ = Dim.init(inDim: [1, 128, 128, 3]) inputDim = Dim.init(inDim: [1, 128, 128, 3])
} }
@objc override public init(device: MTLDevice,paramPointer: UnsafeMutableRawPointer, paramSize:Int, modePointer: UnsafeMutableRawPointer, modelSize: Int) { @objc override public init(device: MTLDevice, paramPointer: UnsafeMutableRawPointer, paramSize:Int, modePointer: UnsafeMutableRawPointer, modelSize: Int) {
super.init(device:device,paramPointer:paramPointer,paramSize:paramSize,modePointer:modePointer,modelSize:modelSize) super.init(device: device,
means = [128.0, 128.0, 128.0] paramPointer: paramPointer,
scale = 0.017 paramSize: paramSize,
except = 0 modePointer: modePointer,
modelPath = "" modelSize: modelSize)
paramPath = ""
modelDir = ""
preprocessKernel = GenetPreProccess.init(device: device) preprocessKernel = GenetPreProccess.init(device: device)
inputDim_ = Dim.init(inDim: [1, 128, 128, 3]) inputDim = Dim.init(inDim: [1, 128, 128, 3])
} }
class GenetPreProccess: CusomKernel { class GenetPreProccess: CusomKernel {
init(device: MTLDevice) { init(device: MTLDevice) {
let s = Shape.init(inWidth: 128, inHeight: 128, inChannel: 3) let s = Shape.init(inWidth: 128, inHeight: 128, inChannel: 3)
super.init(device: device, inFunctionName: "genet_preprocess", outputDim: s, usePaddleMobileLib: false) super.init(device: device, inFunctionName: "genet_preprocess", outputDim: s, metalLoadModel: .LoadMetalInDefaultLib, metalLibPath: nil)
} }
} }
override public func resultStr(res: ResultHolder) -> String { override public func resultStr(res: ResultHolder) -> String {
// fatalError()
return " \(res.result[0]) ... " return " \(res.result[0]) ... "
} }
......
...@@ -13,13 +13,14 @@ ...@@ -13,13 +13,14 @@
limitations under the License. */ limitations under the License. */
import Foundation import Foundation
import paddle_mobile
public class MobileNet: Net{ public class MobileNet: Net{
class MobilenetPreProccess: CusomKernel { class MobilenetPreProccess: CusomKernel {
init(device: MTLDevice) { init(device: MTLDevice) {
let s = Shape.init(inWidth: 224, inHeight: 224, inChannel: 3) let s = Shape.init(inWidth: 224, inHeight: 224, inChannel: 3)
super.init(device: device, inFunctionName: "mobilenet_preprocess", outputDim: s, usePaddleMobileLib: false) super.init(device: device, inFunctionName: "mobilenet_preprocess", outputDim: s, metalLoadModel: .LoadMetalInDefaultLib, metalLibPath: nil)
} }
} }
...@@ -53,14 +54,13 @@ public class MobileNet: Net{ ...@@ -53,14 +54,13 @@ public class MobileNet: Net{
override public init(device: MTLDevice) { override public init(device: MTLDevice) {
super.init(device: device) super.init(device: device)
means = [123.68, 116.78, 103.94]
scale = 0.017
except = 0 except = 0
modelPath = Bundle.main.path(forResource: "mobilenet_model", ofType: nil) ?! "model null" modelPath = Bundle.main.path(forResource: "mobilenet_model", ofType: nil) ?! "model null"
paramPath = Bundle.main.path(forResource: "mobilenet_params", ofType: nil) ?! "para null" paramPath = Bundle.main.path(forResource: "mobilenet_params", ofType: nil) ?! "para null"
modelDir = "" // metalLoadMode = .LoadMetalInCustomMetalLib
// metalLibPath = Bundle.main.path(forResource: "PaddleMobileMetal", ofType: "metallib") ?! " can't be nil "
preprocessKernel = MobilenetPreProccess.init(device: device) preprocessKernel = MobilenetPreProccess.init(device: device)
inputDim_ = Dim.init(inDim: [1, 224, 224, 3]) inputDim = Dim.init(inDim: [1, 224, 224, 3])
} }
} }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
import paddle_mobile
public class MobileNetCombined: Net {
@objc public override init(device: MTLDevice) {
super.init(device: device)
except = 0
modelPath = Bundle.main.path(forResource: "combined_mobilenet_model", ofType: nil) ?! "model null"
paramPath = Bundle.main.path(forResource: "combined_mobilenet_params", ofType: nil) ?! "para null"
inputDim = Dim.init(inDim: [1, 224, 224, 3])
// metalLoadMode = .LoadMetalInCustomMetalLib
// metalLibPath = Bundle.main.path(forResource: "PaddleMobileMetal", ofType: "metallib") ?! " can't be nil "
}
override public func resultStr(res: ResultHolder) -> String {
return " \(res.result[0]) ... "
}
}
...@@ -13,36 +13,35 @@ ...@@ -13,36 +13,35 @@
limitations under the License. */ limitations under the License. */
import Foundation import Foundation
import paddle_mobile
public class MobileNet_ssd_hand: Net{ public class MobileNet_ssd_hand: Net {
@objc public override init(device: MTLDevice) { @objc public override init(device: MTLDevice) {
super.init(device: device) super.init(device: device)
means = [123.68, 116.78, 103.94]
scale = 0.017
except = 2 except = 2
modelPath = Bundle.main.path(forResource: "ssd_hand_model", ofType: nil) ?! "model null" modelPath = Bundle.main.path(forResource: "ssd_hand_model", ofType: nil) ?! "model null"
paramPath = Bundle.main.path(forResource: "ssd_hand_params", ofType: nil) ?! "para null" paramPath = Bundle.main.path(forResource: "ssd_hand_params", ofType: nil) ?! "para null"
modelDir = "" // metalLoadMode = .LoadMetalInCustomMetalLib
// metalLibPath = Bundle.main.path(forResource: "PaddleMobileMetal", ofType: "metallib") ?! " can't be nil "
preprocessKernel = MobilenetssdPreProccess.init(device: device) preprocessKernel = MobilenetssdPreProccess.init(device: device)
inputDim_ = Dim.init(inDim: [1, 300, 300, 3]) inputDim = Dim.init(inDim: [1, 300, 300, 3])
} }
@objc override public init(device: MTLDevice,paramPointer: UnsafeMutableRawPointer, paramSize:Int, modePointer: UnsafeMutableRawPointer, modelSize: Int) { @objc override public init(device: MTLDevice,paramPointer: UnsafeMutableRawPointer, paramSize:Int, modePointer: UnsafeMutableRawPointer, modelSize: Int) {
super.init(device:device,paramPointer:paramPointer,paramSize:paramSize,modePointer:modePointer,modelSize:modelSize) super.init(device:device,paramPointer:paramPointer,paramSize:paramSize,modePointer:modePointer,modelSize:modelSize)
means = [123.68, 116.78, 103.94]
scale = 0.017
except = 2 except = 2
modelPath = "" modelPath = ""
paramPath = "" paramPath = ""
modelDir = "" // metalLoadMode = .LoadMetalInCustomMetalLib
// metalLibPath = Bundle.main.path(forResource: "PaddleMobileMetal", ofType: "metallib") ?! " can't be nil "
preprocessKernel = MobilenetssdPreProccess.init(device: device) preprocessKernel = MobilenetssdPreProccess.init(device: device)
inputDim_ = Dim.init(inDim: [1, 300, 300, 3]) inputDim = Dim.init(inDim: [1, 300, 300, 3])
} }
class MobilenetssdPreProccess: CusomKernel { class MobilenetssdPreProccess: CusomKernel {
init(device: MTLDevice) { init(device: MTLDevice) {
let s = Shape.init(inWidth: 300, inHeight: 300, inChannel: 3) let s = Shape.init(inWidth: 300, inHeight: 300, inChannel: 3)
super.init(device: device, inFunctionName: "mobilenet_ssd_preprocess", outputDim: s, usePaddleMobileLib: false) super.init(device: device, inFunctionName: "mobilenet_ssd_preprocess", outputDim: s, metalLoadModel: .LoadMetalInDefaultLib, metalLibPath: nil)
} }
} }
...@@ -50,7 +49,7 @@ public class MobileNet_ssd_hand: Net{ ...@@ -50,7 +49,7 @@ public class MobileNet_ssd_hand: Net{
return " \(res)" return " \(res)"
} }
override func fetchResult(paddleMobileRes: GPUResultHolder) -> ResultHolder { override public func fetchResult(paddleMobileRes: GPUResultHolder) -> ResultHolder {
// guard let interRes = paddleMobileRes.intermediateResults else { // guard let interRes = paddleMobileRes.intermediateResults else {
// fatalError(" need have inter result ") // fatalError(" need have inter result ")
......
...@@ -13,36 +13,29 @@ ...@@ -13,36 +13,29 @@
limitations under the License. */ limitations under the License. */
import Foundation import Foundation
import paddle_mobile
public class MobileNet_ssd_AR: Net{ public class MobileNet_ssd_AR: Net {
@objc public override init(device: MTLDevice) { @objc public override init(device: MTLDevice) {
super.init(device: device) super.init(device: device)
means = [103.94, 116.78, 123.68]
scale = 1
except = 2 except = 2
modelPath = Bundle.main.path(forResource: "ar_model", ofType: nil) ?! "model null" modelPath = Bundle.main.path(forResource: "ar_model", ofType: nil) ?! "model null"
paramPath = Bundle.main.path(forResource: "ar_params", ofType: nil) ?! "para null" paramPath = Bundle.main.path(forResource: "ar_params", ofType: nil) ?! "para null"
modelDir = ""
preprocessKernel = MobilenetssdPreProccess.init(device: device) preprocessKernel = MobilenetssdPreProccess.init(device: device)
inputDim_ = Dim.init(inDim: [1, 160, 160, 3]) inputDim = Dim.init(inDim: [1, 160, 160, 3])
} }
@objc override public init(device: MTLDevice,paramPointer: UnsafeMutableRawPointer, paramSize:Int, modePointer: UnsafeMutableRawPointer, modelSize: Int) { @objc override public init(device: MTLDevice,paramPointer: UnsafeMutableRawPointer, paramSize:Int, modePointer: UnsafeMutableRawPointer, modelSize: Int) {
super.init(device:device,paramPointer:paramPointer,paramSize:paramSize,modePointer:modePointer,modelSize:modelSize) super.init(device:device,paramPointer:paramPointer,paramSize:paramSize,modePointer:modePointer,modelSize:modelSize)
means = [103.94, 116.78, 123.68]
scale = 1
except = 2 except = 2
modelPath = ""
paramPath = ""
modelDir = ""
preprocessKernel = MobilenetssdPreProccess.init(device: device) preprocessKernel = MobilenetssdPreProccess.init(device: device)
inputDim_ = Dim.init(inDim: [1, 160, 160, 3]) inputDim = Dim.init(inDim: [1, 160, 160, 3])
} }
class MobilenetssdPreProccess: CusomKernel { class MobilenetssdPreProccess: CusomKernel {
init(device: MTLDevice) { init(device: MTLDevice) {
let s = Shape.init(inWidth: 160, inHeight: 160, inChannel: 3) let s = Shape.init(inWidth: 160, inHeight: 160, inChannel: 3)
super.init(device: device, inFunctionName: "mobilent_ar_preprocess", outputDim: s, usePaddleMobileLib: false) super.init(device: device, inFunctionName: "mobilent_ar_preprocess", outputDim: s, metalLoadModel: .LoadMetalInDefaultLib, metalLibPath: nil)
} }
} }
...@@ -50,18 +43,19 @@ public class MobileNet_ssd_AR: Net{ ...@@ -50,18 +43,19 @@ public class MobileNet_ssd_AR: Net{
return " \(res.result[0])" return " \(res.result[0])"
} }
override func fetchResult(paddleMobileRes: GPUResultHolder) -> ResultHolder { override public func fetchResult(paddleMobileRes: GPUResultHolder) -> ResultHolder {
guard let interRes = paddleMobileRes.intermediateResults else { fatalError()
fatalError(" need have inter result ") // guard let interRes = paddleMobileRes.intermediateResults else {
} // fatalError(" need have inter result ")
// }
guard let scores = interRes["Scores"], scores.count > 0, let score = scores[0] as? FetchHolder else { //
fatalError(" need score ") // guard let scores = interRes["Scores"], scores.count > 0, let score = scores[0] as? FetchHolder else {
} // fatalError(" need score ")
// }
guard let bboxs = interRes["BBoxes"], bboxs.count > 0, let bbox = bboxs[0] as? FetchHolder else { //
fatalError() // guard let bboxs = interRes["BBoxes"], bboxs.count > 0, let bbox = bboxs[0] as? FetchHolder else {
} // fatalError()
// }
// let startDate = Date.init() // let startDate = Date.init()
...@@ -72,19 +66,19 @@ public class MobileNet_ssd_AR: Net{ ...@@ -72,19 +66,19 @@ public class MobileNet_ssd_AR: Net{
// //
// print((0..<bbox.capacity).map{ bbox.result[$0] }.strideArray()) // print((0..<bbox.capacity).map{ bbox.result[$0] }.strideArray())
let nmsCompute = NMSCompute.init() // let nmsCompute = NMSCompute.init()
nmsCompute.scoreThredshold = 0.25 // nmsCompute.scoreThredshold = 0.25
nmsCompute.nmsTopK = 100 // nmsCompute.nmsTopK = 100
nmsCompute.keepTopK = 100 // nmsCompute.keepTopK = 100
nmsCompute.nmsEta = 1.0 // nmsCompute.nmsEta = 1.0
nmsCompute.nmsThreshold = 0.449999988 // nmsCompute.nmsThreshold = 0.449999988
nmsCompute.background_label = 0; // nmsCompute.background_label = 0;
nmsCompute.scoreDim = [NSNumber.init(value: score.dim[0]), NSNumber.init(value: score.dim[1]), NSNumber.init(value: score.dim[2])] // nmsCompute.scoreDim = [NSNumber.init(value: score.dim[0]), NSNumber.init(value: score.dim[1]), NSNumber.init(value: score.dim[2])]
nmsCompute.bboxDim = [NSNumber.init(value: bbox.dim[0]), NSNumber.init(value: bbox.dim[1]), NSNumber.init(value: bbox.dim[2])] // nmsCompute.bboxDim = [NSNumber.init(value: bbox.dim[0]), NSNumber.init(value: bbox.dim[1]), NSNumber.init(value: bbox.dim[2])]
guard let result = nmsCompute.compute(withScore: score.result, andBBoxs: bbox.result) else { // guard let result = nmsCompute.compute(withScore: score.result, andBBoxs: bbox.result) else {
fatalError( " result error " ) // fatalError( " result error " )
} // }
let resultHolder = ResultHolder.init(inResult: result.output, inCapacity: Int(result.outputSize)) // let resultHolder = ResultHolder.init(inResult: result.output, inCapacity: Int(result.outputSize))
// for i in 0..<Int(result.outputSize) { // for i in 0..<Int(result.outputSize) {
// //
// print("i \(i) : \(result.output[i])") // print("i \(i) : \(result.output[i])")
...@@ -92,10 +86,11 @@ public class MobileNet_ssd_AR: Net{ ...@@ -92,10 +86,11 @@ public class MobileNet_ssd_AR: Net{
// print(Date.init().timeIntervalSince(startDate)) // print(Date.init().timeIntervalSince(startDate))
// print(resultHolder.result![0]) // print(resultHolder.result![0])
return resultHolder // return resultHolder
} }
override func updateProgram(program: Program) { // override func updateProgram(program: Program) {
// for i in [56, 66, 76, 86, 93, 99] { // for i in [56, 66, 76, 86, 93, 99] {
// let opDesc = program.programDesc.blocks[0].ops[i] // let opDesc = program.programDesc.blocks[0].ops[i]
// let output = opDesc.outputs["Out"]!.first! // let output = opDesc.outputs["Out"]!.first!
...@@ -148,6 +143,6 @@ public class MobileNet_ssd_AR: Net{ ...@@ -148,6 +143,6 @@ public class MobileNet_ssd_AR: Net{
// print(" split axis \(opDesc.attrs["axis"])") // print(" split axis \(opDesc.attrs["axis"])")
// } // }
// 99 // 99
} // }
} }
//
// PaddleMobile.swift
// paddle-mobile-demo
//
// Created by liuRuiLong on 2018/9/5.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
...@@ -115,23 +115,3 @@ kernel void mobilent_ar_preprocess_half(texture2d<half, access::read> inTexture ...@@ -115,23 +115,3 @@ kernel void mobilent_ar_preprocess_half(texture2d<half, access::read> inTexture
const half4 inColor = (inTexture.read(gid) * 255.0 - means) * 0.017; const half4 inColor = (inTexture.read(gid) * 255.0 - means) * 0.017;
outTexture.write(half4(inColor.z, inColor.y, inColor.x, 0.0f), gid); outTexture.write(half4(inColor.z, inColor.y, inColor.x, 0.0f), gid);
} }
kernel void scale(texture2d<float, access::sample> inTexture [[texture(0)]], texture2d<float, access::write> outTexture [[texture(1)]], uint2 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height()) return;
float w_stride = inTexture.get_width() / outTexture.get_width();
float h_stride = inTexture.get_height() / outTexture.get_height();
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 input = inTexture.sample(sample, float2(gid.x * w_stride, gid.y * h_stride), 0);
outTexture.write(input, gid);
}
kernel void scale_half(texture2d<float, access::sample> inTexture [[texture(0)]], texture2d<half, access::write> outTexture [[texture(1)]], uint2 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height()) return;
float w_stride = inTexture.get_width() / outTexture.get_width();
float h_stride = inTexture.get_height() / outTexture.get_height();
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 input = inTexture.sample(sample, float2(gid.x * w_stride, gid.y * h_stride), 0);
outTexture.write(half4(input), gid);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Metal
import Foundation
import paddle_mobile
public class YoloNet: Net {
@objc public override init(device: MTLDevice) {
super.init(device: device)
except = 0
modelPath = Bundle.main.path(forResource: "yolo_model", ofType: nil) ?! "model null"
paramPath = Bundle.main.path(forResource: "yolo_params", ofType: nil) ?! "para null"
inputDim = Dim.init(inDim: [1, 416, 416, 3])
// metalLoadMode = .LoadMetalInCustomMetalLib
// metalLibPath = Bundle.main.path(forResource: "PaddleMobileMetal", ofType: "metallib") ?! " can't be nil "
}
override public func resultStr(res: ResultHolder) -> String {
return " \(res.result[0]) ... "
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#import <UIKit/UIKit.h>
@interface LoadPointerViewController : UIViewController
@end
// /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// LoadPointerViewController.m
// paddle-mobile-demo Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// Created by Xiao,Haichun on 2018/9/19. You may obtain a copy of the License at
// Copyright © 2018年 orange. All rights reserved.
// http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#import "PaddleMobileGPU.h"
#import "LoadPointerViewController.h" #import "LoadPointerViewController.h"
#import <Metal/Metal.h>
#import "paddle-mobile-demo-Bridging-Header.h" #import "paddle-mobile-demo-Bridging-Header.h"
#import <Metal/Metal.h>
@interface LoadPointerViewController () @interface LoadPointerViewController ()
@property (strong, nonatomic) id<MTLDevice> device; @property (strong, nonatomic) id<MTLDevice> device;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#import <Foundation/Foundation.h>
NS_ASSUME_NONNULL_BEGIN
@interface OCDemoViewController : NSObject
@end
NS_ASSUME_NONNULL_END
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#import "OCDemoViewController.h"
@implementation OCDemoViewController
@end
...@@ -16,9 +16,8 @@ ...@@ -16,9 +16,8 @@
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
typedef enum : NSUInteger { typedef enum : NSUInteger {
MobileNetType, SuperResolutionNetType,
MobileNetSSDType, MobileNetSSDType
GenetType,
} NetType; } NetType;
@interface PaddleMobileGPUResult: NSObject @interface PaddleMobileGPUResult: NSObject
......
...@@ -12,11 +12,10 @@ ...@@ -12,11 +12,10 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#import "paddle_mobile.h"
#import "PaddleMobileGPU.h" #import "PaddleMobileGPU.h"
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
#import <paddle_mobile/paddle_mobile-Swift.h> #import <paddle_mobile_demo-Swift.h>
@implementation ModelConfig @implementation ModelConfig
@end @end
...@@ -53,12 +52,10 @@ ...@@ -53,12 +52,10 @@
self = [super init]; self = [super init];
if (self) { if (self) {
Net *net = nil; Net *net = nil;
if (netType == GenetType) { if (netType == SuperResolutionNetType) {
net = [[Genet alloc] initWithDevice:queue.device paramPointer:config.paramPointer paramSize:config.paramSize modePointer:config.modelPointer modelSize:config.modelSize]; net = [[SuperResolutionNet alloc] initWithDevice:queue.device];
} else if (netType == MobileNetSSDType) { } else if (netType == MobileNetSSDType) {
net = [[MobileNet_ssd_AR alloc] initWithDevice:queue.device paramPointer:config.paramPointer paramSize:config.paramSize modePointer:config.modelPointer modelSize:config.modelSize]; net = [[MobileNet_ssd_AR alloc] initWithDevice:queue.device paramPointer:config.paramPointer paramSize:config.paramSize modePointer:config.modelPointer modelSize:config.modelSize];
} else if (netType == MobileNetType) {
} }
runner = [[Runner alloc] initInNet:net commandQueue:queue]; runner = [[Runner alloc] initInNet:net commandQueue:queue];
} }
......
...@@ -13,31 +13,25 @@ ...@@ -13,31 +13,25 @@
limitations under the License. */ limitations under the License. */
import Foundation import Foundation
import paddle_mobile
@objc public class SuperResolutionNet: Net{
public class SuperResolutionNet: Net{
override public func resultStr(res: ResultHolder) -> String { override public func resultStr(res: ResultHolder) -> String {
return "未实现" return "未实现"
} }
override public init(device: MTLDevice) { @objc override public init(device: MTLDevice) {
super.init(device: device) super.init(device: device)
means = [0.0, 0.0, 0.0]
scale = 1.0
except = 0 except = 0
modelPath = Bundle.main.path(forResource: "super_model", ofType: nil) ?! "model null" modelPath = Bundle.main.path(forResource: "super_model", ofType: nil) ?! "model null"
paramPath = Bundle.main.path(forResource: "super_params", ofType: nil) ?! "para null" paramPath = Bundle.main.path(forResource: "super_params", ofType: nil) ?! "para null"
modelDir = ""
preprocessKernel = nil preprocessKernel = nil
// inputDim_ = Dim.init(inDim: [1, Int(552 * 1.414), Int(310 * 1.414), 1]) inputDim = Dim.init(inDim: [1, 224, 224, 1])
inputDim_ = Dim.init(inDim: [1, 224, 224, 1]) // metalLoadMode = .LoadMetalInCustomMetalLib
// metalLibPath = Bundle.main.path(forResource: "PaddleMobileMetal", ofType: "metallib") ?! " can't be nil "
} }
override func updateProgram(program: Program) { override public func updateProgram(program: Program) {
guard needUpdateProgram else {
return
}
// n h w c // n h w c
for block in program.programDesc.blocks { for block in program.programDesc.blocks {
for varDesc in block.vars { for varDesc in block.vars {
...@@ -47,8 +41,9 @@ public class SuperResolutionNet: Net{ ...@@ -47,8 +41,9 @@ public class SuperResolutionNet: Net{
if let texture = varEle as? Texture { if let texture = varEle as? Texture {
let newDim = Dim.init(inDim: [texture.dim[0], inputDim[1], inputDim[2], texture.tensorDim[1]]) let newDim = Dim.init(inDim: [texture.dim[0], inputDim[1], inputDim[2], texture.tensorDim[1]])
print(" var desc name " + varDesc.name + " new dim" + "\(newDim)") print(" var desc name " + varDesc.name + " new dim" + "\(newDim)")
texture.updateDims(inTensorDim: Dim.init(inDim: [texture.tensorDim[0], texture.tensorDim[1], inputDim[1], inputDim[2]]), inDim: newDim) texture.updateDims(inTensorDim: Dim.init(inDim: [texture.tensorDim[0], texture.tensorDim[1], inputDim[1], inputDim[2]]), inDim: newDim)
texture.initTexture(device: device, inTranspose: [0, 1, 2, 3], computePrecision: computePrecision) texture.initTexture(device: device, inTranspose: [0, 1, 2, 3], computePrecision: GlobalConfig.shared.computePrecision)
let output: FetchHolder = program.scope.output() as! FetchHolder let output: FetchHolder = program.scope.output() as! FetchHolder
output.dim = newDim output.dim = newDim
...@@ -60,7 +55,6 @@ public class SuperResolutionNet: Net{ ...@@ -60,7 +55,6 @@ public class SuperResolutionNet: Net{
} }
} }
} }
needUpdateProgram = false
} }
} }
...@@ -18,7 +18,6 @@ import CoreMedia ...@@ -18,7 +18,6 @@ import CoreMedia
import paddle_mobile import paddle_mobile
import MetalPerformanceShaders import MetalPerformanceShaders
class FileReader { class FileReader {
let file: UnsafeMutablePointer<FILE> let file: UnsafeMutablePointer<FILE>
let fileSize: Int let fileSize: Int
...@@ -53,16 +52,12 @@ enum Platform { ...@@ -53,16 +52,12 @@ enum Platform {
let platformSupport: [(Platform, String)] = [(.GPU, "GPU")] let platformSupport: [(Platform, String)] = [(.GPU, "GPU")]
enum SupportModel: String{ enum SupportModel: String{
// case mobilenet = "mobilenet"
// case mobilenet_ssd = "mobilenetssd"
case yolo = "yolo" case yolo = "yolo"
case mobilenet_combined = "mobilenet_combined" case mobilenet_combined = "mobilenet_combined"
case super_resolution = "superresoltion" case super_resolution = "superresoltion"
case mobilenet = "mobilenet" case mobilenet = "mobilenet"
static func supportedModels() -> [SupportModel] { static func supportedModels() -> [SupportModel] {
// .mobilenet,
// .mobilenet_ssd,
return [.super_resolution, .yolo, .mobilenet_combined, .mobilenet] return [.super_resolution, .yolo, .mobilenet_combined, .mobilenet]
} }
} }
...@@ -94,24 +89,25 @@ class ViewController: UIViewController { ...@@ -94,24 +89,25 @@ class ViewController: UIViewController {
@IBAction func loadAct(_ sender: Any) { @IBAction func loadAct(_ sender: Any) {
runner = Runner.init(inNet: netSupport[modelType]!, commandQueue: MetalHelper.shared.queue) runner = Runner.init(inNet: netSupport[modelType]!, commandQueue: MetalHelper.shared.queue)
if platform == .GPU { if platform == .GPU {
let filePath = Bundle.main.path(forResource: "mingren_input_data", ofType: nil) // let filePath = Bundle.main.path(forResource: "mingren_input_data", ofType: nil)
let fileReader = try! FileReader.init(paramPath: filePath!) // let fileReader = try! FileReader.init(paramPath: filePath!)
let pointer: UnsafeMutablePointer<Float32> = fileReader.read() // let pointer: UnsafeMutablePointer<Float32> = fileReader.read()
//
//
let buffer = MetalHelper.shared.device.makeBuffer(length: fileReader.fileSize, options: .storageModeShared) // let buffer = MetalHelper.shared.device.makeBuffer(length: fileReader.fileSize, options: .storageModeShared)
//
buffer?.contents().copyMemory(from: pointer, byteCount: fileReader.fileSize) // buffer?.contents().copyMemory(from: pointer, byteCount: fileReader.fileSize)
if self.toPredictTexture == nil { if self.toPredictTexture == nil {
runner.getTexture(inBuffer: buffer!) { [weak self] (texture) in // runner.getTexture(inBuffer: buffer!) { [weak self] (texture) in
// self?.toPredictTexture = texture
// }
runner.getTexture(image: selectImage!.cgImage!) { [weak self] (texture) in
self?.toPredictTexture = texture self?.toPredictTexture = texture
} }
// runner.getTexture(image: selectImage!.cgImage!) { [weak self] (texture) in
// }
} }
} else { } else {
fatalError( " unsupport " ) fatalError( " unsupport " )
...@@ -178,13 +174,14 @@ class ViewController: UIViewController { ...@@ -178,13 +174,14 @@ class ViewController: UIViewController {
modelPickerView.dataSource = self modelPickerView.dataSource = self
threadPickerView.delegate = self threadPickerView.delegate = self
threadPickerView.dataSource = self threadPickerView.dataSource = self
if let image = UIImage.init(named: "test.jpg") { if let image = UIImage.init(named: "classify-img-output.png") {
selectImage = image selectImage = image
selectImageView.image = image selectImageView.image = image
} else { } else {
print("请添加测试图片") print("请添加测试图片")
} }
GlobalConfig.shared.computePrecision = .Float32
// if platform == .CPU { // if platform == .CPU {
// inputPointer = runner.preproccess(image: selectImage!.cgImage!) // inputPointer = runner.preproccess(image: selectImage!.cgImage!)
......
...@@ -28,8 +28,6 @@ ...@@ -28,8 +28,6 @@
4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; }; 4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; };
4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; }; 4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; };
4AF9288421357BE3005B6C3A /* Elementwise.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9288321357BE3005B6C3A /* Elementwise.metal */; }; 4AF9288421357BE3005B6C3A /* Elementwise.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9288321357BE3005B6C3A /* Elementwise.metal */; };
C28FDF8421B7858F0054EFAC /* MobileNetCombined.swift in Sources */ = {isa = PBXBuildFile; fileRef = C28FDF8221B7858F0054EFAC /* MobileNetCombined.swift */; };
C28FDF8521B7858F0054EFAC /* YoloNet.swift in Sources */ = {isa = PBXBuildFile; fileRef = C28FDF8321B7858F0054EFAC /* YoloNet.swift */; };
C28FE02F21BA68C00054EFAC /* Metal.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C28FE02C21BA68C00054EFAC /* Metal.framework */; }; C28FE02F21BA68C00054EFAC /* Metal.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C28FE02C21BA68C00054EFAC /* Metal.framework */; };
C28FE03021BA68C00054EFAC /* MetalPerformanceShaders.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C28FE02D21BA68C00054EFAC /* MetalPerformanceShaders.framework */; }; C28FE03021BA68C00054EFAC /* MetalPerformanceShaders.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C28FE02D21BA68C00054EFAC /* MetalPerformanceShaders.framework */; };
C28FE03121BA68C00054EFAC /* MetalKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C28FE02E21BA68C00054EFAC /* MetalKit.framework */; }; C28FE03121BA68C00054EFAC /* MetalKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C28FE02E21BA68C00054EFAC /* MetalKit.framework */; };
...@@ -64,18 +62,16 @@ ...@@ -64,18 +62,16 @@
FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */; }; FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */; };
FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC1B16B220EC9A4F00678B91 /* Kernels.metal */; }; FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC1B16B220EC9A4F00678B91 /* Kernels.metal */; };
FC1CF3F721D4B4C400F7392E /* Runner.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC1CF3F621D4B4C400F7392E /* Runner.swift */; }; FC1CF3F721D4B4C400F7392E /* Runner.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC1CF3F621D4B4C400F7392E /* Runner.swift */; };
FC292C5421421B2F00CF622F /* PaddleMobileGPU.h in Headers */ = {isa = PBXBuildFile; fileRef = FC292C5321421B2E00CF622F /* PaddleMobileGPU.h */; settings = {ATTRIBUTES = (Public, ); }; }; FC2BFCC221DF2F9100C262B2 /* GlobalConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC2BFCC121DF2F9100C262B2 /* GlobalConfig.swift */; };
FC292C5621421B4600CF622F /* PaddleMobileGPU.m in Sources */ = {isa = PBXBuildFile; fileRef = FC292C5521421B4600CF622F /* PaddleMobileGPU.m */; }; FC2BFD4621DF685F00C262B2 /* Scale.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC2BFD4521DF685F00C262B2 /* Scale.swift */; };
FC292C81214255BD00CF622F /* CPUCompute.mm in Sources */ = {isa = PBXBuildFile; fileRef = FC292C7C214255BC00CF622F /* CPUCompute.mm */; }; FC2BFD4A21DF81DE00C262B2 /* Kernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC2BFD4921DF81DE00C262B2 /* Kernel.swift */; };
FC292C82214255BD00CF622F /* MobileNetSSD.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC292C7E214255BC00CF622F /* MobileNetSSD.swift */; }; FC2BFD4E21DF820B00C262B2 /* ConvAddBatchNormReluOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC2BFD4D21DF820A00C262B2 /* ConvAddBatchNormReluOp.swift */; };
FC292C872142624800CF622F /* Genet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC292C862142624800CF622F /* Genet.swift */; }; FC2BFD5121DF8E0400C262B2 /* Scale.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC2BFD5021DF8E0400C262B2 /* Scale.metal */; };
FC33B0F02147659000714A93 /* MobileNet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC33B0EF2147659000714A93 /* MobileNet.swift */; };
FC3602CC2108819F00FACB58 /* PaddleMobileUnitTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3602CB2108819F00FACB58 /* PaddleMobileUnitTest.swift */; }; FC3602CC2108819F00FACB58 /* PaddleMobileUnitTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC3602CB2108819F00FACB58 /* PaddleMobileUnitTest.swift */; };
FC4CB74920F0B954007C0C6D /* ConvKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC4CB74820F0B954007C0C6D /* ConvKernel.metal */; }; FC4CB74920F0B954007C0C6D /* ConvKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC4CB74820F0B954007C0C6D /* ConvKernel.metal */; };
FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */; }; FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */; };
FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */; }; FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */; };
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC60DB8820E9AAA500FF203F /* MetalExtension.swift */; }; FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC60DB8820E9AAA500FF203F /* MetalExtension.swift */; };
FC704C2721D2385100F98BAB /* SuperResolutionNet.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC704C2621D2385100F98BAB /* SuperResolutionNet.swift */; };
FC803BBF214CB65A0094B8E5 /* ConvAddPreluOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC803BBE214CB65A0094B8E5 /* ConvAddPreluOp.swift */; }; FC803BBF214CB65A0094B8E5 /* ConvAddPreluOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC803BBE214CB65A0094B8E5 /* ConvAddPreluOp.swift */; };
FC803BC1214CB77A0094B8E5 /* ConvAddPreluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC803BC0214CB77A0094B8E5 /* ConvAddPreluKernel.swift */; }; FC803BC1214CB77A0094B8E5 /* ConvAddPreluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC803BC0214CB77A0094B8E5 /* ConvAddPreluKernel.swift */; };
FC803BC3214CB79C0094B8E5 /* ConvAddPreluKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC803BC2214CB79C0094B8E5 /* ConvAddPreluKernel.metal */; }; FC803BC3214CB79C0094B8E5 /* ConvAddPreluKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC803BC2214CB79C0094B8E5 /* ConvAddPreluKernel.metal */; };
...@@ -85,8 +81,6 @@ ...@@ -85,8 +81,6 @@
FC82735920E3C04200BE430A /* OpCreator.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC82735820E3C04200BE430A /* OpCreator.swift */; }; FC82735920E3C04200BE430A /* OpCreator.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC82735820E3C04200BE430A /* OpCreator.swift */; };
FC9797C921D6101D00F2FD90 /* ResizeBilinearOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9797C821D6101D00F2FD90 /* ResizeBilinearOp.swift */; }; FC9797C921D6101D00F2FD90 /* ResizeBilinearOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9797C821D6101D00F2FD90 /* ResizeBilinearOp.swift */; };
FC9797CB21D6102D00F2FD90 /* ResizeBilinearKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9797CA21D6102D00F2FD90 /* ResizeBilinearKernel.swift */; }; FC9797CB21D6102D00F2FD90 /* ResizeBilinearKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9797CA21D6102D00F2FD90 /* ResizeBilinearKernel.swift */; };
FC9797CD21D61B2E00F2FD90 /* CPUCompute.h in Headers */ = {isa = PBXBuildFile; fileRef = FC292C7D214255BC00CF622F /* CPUCompute.h */; settings = {ATTRIBUTES = (Public, ); }; };
FC9A19E32148C31300CD9CBF /* MobilenetSSD_AR.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9A19E22148C31300CD9CBF /* MobilenetSSD_AR.swift */; };
FC9C2A0D21D3D185005856C6 /* FetchKernel.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC9C2A0C21D3D185005856C6 /* FetchKernel.inc.metal */; }; FC9C2A0D21D3D185005856C6 /* FetchKernel.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC9C2A0C21D3D185005856C6 /* FetchKernel.inc.metal */; };
FC9D037920E229E4000F735A /* OpParam.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037820E229E4000F735A /* OpParam.swift */; }; FC9D037920E229E4000F735A /* OpParam.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037820E229E4000F735A /* OpParam.swift */; };
FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037F20E22FBB000F735A /* FeedOp.swift */; }; FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037F20E22FBB000F735A /* FeedOp.swift */; };
...@@ -138,9 +132,7 @@ ...@@ -138,9 +132,7 @@
FCE9D7B9214FAA4800B520C3 /* NMSFetchResultKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCE9D7B8214FAA4800B520C3 /* NMSFetchResultKernel.metal */; }; FCE9D7B9214FAA4800B520C3 /* NMSFetchResultKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCE9D7B8214FAA4800B520C3 /* NMSFetchResultKernel.metal */; };
FCEB684A212F00DB00D2448E /* PreluKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCEB6849212F00DB00D2448E /* PreluKernel.metal */; }; FCEB684A212F00DB00D2448E /* PreluKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCEB6849212F00DB00D2448E /* PreluKernel.metal */; };
FCEB684C212F093800D2448E /* PreluOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCEB684B212F093800D2448E /* PreluOp.swift */; }; FCEB684C212F093800D2448E /* PreluOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCEB684B212F093800D2448E /* PreluOp.swift */; };
FCEBC0F420F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCEBC0F320F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift */; };
FCEBC0F620F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCEBC0F520F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift */; }; FCEBC0F620F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCEBC0F520F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift */; };
FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCF2D73720E64E70007AC5F5 /* Kernel.swift */; };
/* End PBXBuildFile section */ /* End PBXBuildFile section */
/* Begin PBXFileReference section */ /* Begin PBXFileReference section */
...@@ -166,8 +158,6 @@ ...@@ -166,8 +158,6 @@
4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; }; 4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; };
4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = "<group>"; }; 4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = "<group>"; };
4AF9288321357BE3005B6C3A /* Elementwise.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Elementwise.metal; sourceTree = "<group>"; }; 4AF9288321357BE3005B6C3A /* Elementwise.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Elementwise.metal; sourceTree = "<group>"; };
C28FDF8221B7858F0054EFAC /* MobileNetCombined.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MobileNetCombined.swift; sourceTree = "<group>"; };
C28FDF8321B7858F0054EFAC /* YoloNet.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = YoloNet.swift; sourceTree = "<group>"; };
C28FE02C21BA68C00054EFAC /* Metal.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Metal.framework; path = System/Library/Frameworks/Metal.framework; sourceTree = SDKROOT; }; C28FE02C21BA68C00054EFAC /* Metal.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Metal.framework; path = System/Library/Frameworks/Metal.framework; sourceTree = SDKROOT; };
C28FE02D21BA68C00054EFAC /* MetalPerformanceShaders.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalPerformanceShaders.framework; path = System/Library/Frameworks/MetalPerformanceShaders.framework; sourceTree = SDKROOT; }; C28FE02D21BA68C00054EFAC /* MetalPerformanceShaders.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalPerformanceShaders.framework; path = System/Library/Frameworks/MetalPerformanceShaders.framework; sourceTree = SDKROOT; };
C28FE02E21BA68C00054EFAC /* MetalKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalKit.framework; path = System/Library/Frameworks/MetalKit.framework; sourceTree = SDKROOT; }; C28FE02E21BA68C00054EFAC /* MetalKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalKit.framework; path = System/Library/Frameworks/MetalKit.framework; sourceTree = SDKROOT; };
...@@ -206,20 +196,17 @@ ...@@ -206,20 +196,17 @@
FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ElementwiseAddKernel.swift; sourceTree = "<group>"; }; FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ElementwiseAddKernel.swift; sourceTree = "<group>"; };
FC1B16B220EC9A4F00678B91 /* Kernels.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = Kernels.metal; sourceTree = "<group>"; }; FC1B16B220EC9A4F00678B91 /* Kernels.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = Kernels.metal; sourceTree = "<group>"; };
FC1CF3F621D4B4C400F7392E /* Runner.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Runner.swift; sourceTree = "<group>"; }; FC1CF3F621D4B4C400F7392E /* Runner.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Runner.swift; sourceTree = "<group>"; };
FC292C5321421B2E00CF622F /* PaddleMobileGPU.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = PaddleMobileGPU.h; sourceTree = "<group>"; }; FC2BFCC121DF2F9100C262B2 /* GlobalConfig.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GlobalConfig.swift; sourceTree = "<group>"; };
FC292C5521421B4600CF622F /* PaddleMobileGPU.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = PaddleMobileGPU.m; sourceTree = "<group>"; }; FC2BFD4521DF685F00C262B2 /* Scale.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Scale.swift; sourceTree = "<group>"; };
FC292C7C214255BC00CF622F /* CPUCompute.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = CPUCompute.mm; sourceTree = "<group>"; }; FC2BFD4921DF81DE00C262B2 /* Kernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Kernel.swift; sourceTree = "<group>"; };
FC292C7D214255BC00CF622F /* CPUCompute.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = CPUCompute.h; sourceTree = "<group>"; }; FC2BFD4D21DF820A00C262B2 /* ConvAddBatchNormReluOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ConvAddBatchNormReluOp.swift; sourceTree = "<group>"; };
FC292C7E214255BC00CF622F /* MobileNetSSD.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MobileNetSSD.swift; sourceTree = "<group>"; }; FC2BFD5021DF8E0400C262B2 /* Scale.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = Scale.metal; sourceTree = "<group>"; };
FC292C862142624800CF622F /* Genet.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Genet.swift; sourceTree = "<group>"; };
FC33B0EF2147659000714A93 /* MobileNet.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MobileNet.swift; sourceTree = "<group>"; };
FC3602CB2108819F00FACB58 /* PaddleMobileUnitTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PaddleMobileUnitTest.swift; sourceTree = "<group>"; }; FC3602CB2108819F00FACB58 /* PaddleMobileUnitTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PaddleMobileUnitTest.swift; sourceTree = "<group>"; };
FC4CB74820F0B954007C0C6D /* ConvKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvKernel.metal; sourceTree = "<group>"; }; FC4CB74820F0B954007C0C6D /* ConvKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvKernel.metal; sourceTree = "<group>"; };
FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ProgramOptimize.swift; sourceTree = "<group>"; }; FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ProgramOptimize.swift; sourceTree = "<group>"; };
FC4FD97D2140F2C30073E130 /* libstdc++.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = "libstdc++.tbd"; path = "usr/lib/libstdc++.tbd"; sourceTree = SDKROOT; }; FC4FD97D2140F2C30073E130 /* libstdc++.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = "libstdc++.tbd"; path = "usr/lib/libstdc++.tbd"; sourceTree = SDKROOT; };
FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture2DTo2DArrayKernel.swift; sourceTree = "<group>"; }; FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture2DTo2DArrayKernel.swift; sourceTree = "<group>"; };
FC60DB8820E9AAA500FF203F /* MetalExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalExtension.swift; sourceTree = "<group>"; }; FC60DB8820E9AAA500FF203F /* MetalExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalExtension.swift; sourceTree = "<group>"; };
FC704C2621D2385100F98BAB /* SuperResolutionNet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SuperResolutionNet.swift; sourceTree = "<group>"; };
FC803BBE214CB65A0094B8E5 /* ConvAddPreluOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddPreluOp.swift; sourceTree = "<group>"; }; FC803BBE214CB65A0094B8E5 /* ConvAddPreluOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddPreluOp.swift; sourceTree = "<group>"; };
FC803BC0214CB77A0094B8E5 /* ConvAddPreluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddPreluKernel.swift; sourceTree = "<group>"; }; FC803BC0214CB77A0094B8E5 /* ConvAddPreluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddPreluKernel.swift; sourceTree = "<group>"; };
FC803BC2214CB79C0094B8E5 /* ConvAddPreluKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvAddPreluKernel.metal; sourceTree = "<group>"; }; FC803BC2214CB79C0094B8E5 /* ConvAddPreluKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvAddPreluKernel.metal; sourceTree = "<group>"; };
...@@ -229,7 +216,6 @@ ...@@ -229,7 +216,6 @@
FC82735820E3C04200BE430A /* OpCreator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpCreator.swift; sourceTree = "<group>"; }; FC82735820E3C04200BE430A /* OpCreator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpCreator.swift; sourceTree = "<group>"; };
FC9797C821D6101D00F2FD90 /* ResizeBilinearOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ResizeBilinearOp.swift; sourceTree = "<group>"; }; FC9797C821D6101D00F2FD90 /* ResizeBilinearOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ResizeBilinearOp.swift; sourceTree = "<group>"; };
FC9797CA21D6102D00F2FD90 /* ResizeBilinearKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ResizeBilinearKernel.swift; sourceTree = "<group>"; }; FC9797CA21D6102D00F2FD90 /* ResizeBilinearKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ResizeBilinearKernel.swift; sourceTree = "<group>"; };
FC9A19E22148C31300CD9CBF /* MobilenetSSD_AR.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MobilenetSSD_AR.swift; sourceTree = "<group>"; };
FC9C2A0C21D3D185005856C6 /* FetchKernel.inc.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = FetchKernel.inc.metal; sourceTree = "<group>"; }; FC9C2A0C21D3D185005856C6 /* FetchKernel.inc.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = FetchKernel.inc.metal; sourceTree = "<group>"; };
FC9D037820E229E4000F735A /* OpParam.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpParam.swift; sourceTree = "<group>"; }; FC9D037820E229E4000F735A /* OpParam.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpParam.swift; sourceTree = "<group>"; };
FC9D037F20E22FBB000F735A /* FeedOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FeedOp.swift; sourceTree = "<group>"; }; FC9D037F20E22FBB000F735A /* FeedOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FeedOp.swift; sourceTree = "<group>"; };
...@@ -281,9 +267,7 @@ ...@@ -281,9 +267,7 @@
FCE9D7B8214FAA4800B520C3 /* NMSFetchResultKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = NMSFetchResultKernel.metal; sourceTree = "<group>"; }; FCE9D7B8214FAA4800B520C3 /* NMSFetchResultKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = NMSFetchResultKernel.metal; sourceTree = "<group>"; };
FCEB6849212F00DB00D2448E /* PreluKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = PreluKernel.metal; sourceTree = "<group>"; }; FCEB6849212F00DB00D2448E /* PreluKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = PreluKernel.metal; sourceTree = "<group>"; };
FCEB684B212F093800D2448E /* PreluOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PreluOp.swift; sourceTree = "<group>"; }; FCEB684B212F093800D2448E /* PreluOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PreluOp.swift; sourceTree = "<group>"; };
FCEBC0F320F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; name = ConvAddBatchNormReluOp.swift; path = "paddle-mobile/Operators/ConvAddBatchNormReluOp.swift"; sourceTree = SOURCE_ROOT; };
FCEBC0F520F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddBatchNormReluKernel.swift; sourceTree = "<group>"; }; FCEBC0F520F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddBatchNormReluKernel.swift; sourceTree = "<group>"; };
FCF2D73720E64E70007AC5F5 /* Kernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; name = Kernel.swift; path = "paddle-mobile/Operators/Kernels/Base/Kernel.swift"; sourceTree = SOURCE_ROOT; };
/* End PBXFileReference section */ /* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */ /* Begin PBXFrameworksBuildPhase section */
...@@ -343,17 +327,10 @@ ...@@ -343,17 +327,10 @@
FC039B6C20E11C3C0081E9F8 /* paddle-mobile */ = { FC039B6C20E11C3C0081E9F8 /* paddle-mobile */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC9797CC21D61A5500F2FD90 /* CustomNet */, FC2BFD4721DF818000C262B2 /* API */,
FCE9D7B6214F869000B520C3 /* Net.swift */, FC2BFD4821DF818000C262B2 /* Src */,
FC292C5521421B4600CF622F /* PaddleMobileGPU.m */,
FC292C5321421B2E00CF622F /* PaddleMobileGPU.h */,
FC039BAE20E11CC20081E9F8 /* Program */,
FC039BA320E11CBC0081E9F8 /* Operators */,
FC039B9C20E11CB20081E9F8 /* framework */,
FC039B9320E11C9A0081E9F8 /* Common */,
FC039B6D20E11C3C0081E9F8 /* paddle_mobile.h */, FC039B6D20E11C3C0081E9F8 /* paddle_mobile.h */,
FC039B6E20E11C3C0081E9F8 /* Info.plist */, FC039B6E20E11C3C0081E9F8 /* Info.plist */,
FC1CF3F621D4B4C400F7392E /* Runner.swift */,
); );
path = "paddle-mobile"; path = "paddle-mobile";
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -386,10 +363,10 @@ ...@@ -386,10 +363,10 @@
FC039BA320E11CBC0081E9F8 /* Operators */ = { FC039BA320E11CBC0081E9F8 /* Operators */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC9797C821D6101D00F2FD90 /* ResizeBilinearOp.swift */,
FC086BA520E67E8500D85EF7 /* Kernels */, FC086BA520E67E8500D85EF7 /* Kernels */,
FCD592FA20E248EC00252966 /* Base */, FCD592FA20E248EC00252966 /* Base */,
FCEBC0F320F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift */, FC9797C821D6101D00F2FD90 /* ResizeBilinearOp.swift */,
FC2BFD4D21DF820A00C262B2 /* ConvAddBatchNormReluOp.swift */,
FC039BA420E11CBC0081E9F8 /* ConvOp.swift */, FC039BA420E11CBC0081E9F8 /* ConvOp.swift */,
FC039BA520E11CBC0081E9F8 /* ElementwiseAddOp.swift */, FC039BA520E11CBC0081E9F8 /* ElementwiseAddOp.swift */,
FC039BA720E11CBC0081E9F8 /* BatchNormOp.swift */, FC039BA720E11CBC0081E9F8 /* BatchNormOp.swift */,
...@@ -441,9 +418,9 @@ ...@@ -441,9 +418,9 @@
FC086BA520E67E8500D85EF7 /* Kernels */ = { FC086BA520E67E8500D85EF7 /* Kernels */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC9797CA21D6102D00F2FD90 /* ResizeBilinearKernel.swift */,
FCDDC6CD212FE02100E5EF74 /* Base */, FCDDC6CD212FE02100E5EF74 /* Base */,
FCEB6837212F00B100D2448E /* metal */, FCEB6837212F00B100D2448E /* metal */,
FC9797CA21D6102D00F2FD90 /* ResizeBilinearKernel.swift */,
FCDDC6C7212FA3CA00E5EF74 /* ConvTransposeKernel.swift */, FCDDC6C7212FA3CA00E5EF74 /* ConvTransposeKernel.swift */,
FC0E2DBB20EE45FE009C1FAC /* ConvKernel.swift */, FC0E2DBB20EE45FE009C1FAC /* ConvKernel.swift */,
FC0E2DB920EE3B8D009C1FAC /* ReluKernel.swift */, FC0E2DB920EE3B8D009C1FAC /* ReluKernel.swift */,
...@@ -469,24 +446,30 @@ ...@@ -469,24 +446,30 @@
FC803BC0214CB77A0094B8E5 /* ConvAddPreluKernel.swift */, FC803BC0214CB77A0094B8E5 /* ConvAddPreluKernel.swift */,
FCE3A1AA2153DE8C00C37CDE /* ConvAddAddPreluKernel.swift */, FCE3A1AA2153DE8C00C37CDE /* ConvAddAddPreluKernel.swift */,
FCE3A1AE2153E8EE00C37CDE /* ElementwiseAddPreluKernel.swift */, FCE3A1AE2153E8EE00C37CDE /* ElementwiseAddPreluKernel.swift */,
FC2BFD4521DF685F00C262B2 /* Scale.swift */,
); );
path = Kernels; path = Kernels;
sourceTree = "<group>"; sourceTree = "<group>";
}; };
FC9797CC21D61A5500F2FD90 /* CustomNet */ = { FC2BFD4721DF818000C262B2 /* API */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC292C7C214255BC00CF622F /* CPUCompute.mm */, FCE9D7B6214F869000B520C3 /* Net.swift */,
FC292C7D214255BC00CF622F /* CPUCompute.h */, FC1CF3F621D4B4C400F7392E /* Runner.swift */,
FC704C2621D2385100F98BAB /* SuperResolutionNet.swift */, FC2BFCC121DF2F9100C262B2 /* GlobalConfig.swift */,
C28FDF8221B7858F0054EFAC /* MobileNetCombined.swift */, );
C28FDF8321B7858F0054EFAC /* YoloNet.swift */, path = API;
FC9A19E22148C31300CD9CBF /* MobilenetSSD_AR.swift */, sourceTree = "<group>";
FC33B0EF2147659000714A93 /* MobileNet.swift */, };
FC292C862142624800CF622F /* Genet.swift */, FC2BFD4821DF818000C262B2 /* Src */ = {
FC292C7E214255BC00CF622F /* MobileNetSSD.swift */, isa = PBXGroup;
children = (
FC039BAE20E11CC20081E9F8 /* Program */,
FC039BA320E11CBC0081E9F8 /* Operators */,
FC039B9C20E11CB20081E9F8 /* framework */,
FC039B9320E11C9A0081E9F8 /* Common */,
); );
path = CustomNet; path = Src;
sourceTree = "<group>"; sourceTree = "<group>";
}; };
FCD592FA20E248EC00252966 /* Base */ = { FCD592FA20E248EC00252966 /* Base */ = {
...@@ -502,7 +485,7 @@ ...@@ -502,7 +485,7 @@
FCDDC6CD212FE02100E5EF74 /* Base */ = { FCDDC6CD212FE02100E5EF74 /* Base */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FCF2D73720E64E70007AC5F5 /* Kernel.swift */, FC2BFD4921DF81DE00C262B2 /* Kernel.swift */,
); );
path = Base; path = Base;
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -547,6 +530,7 @@ ...@@ -547,6 +530,7 @@
FCE3A1B02153E90F00C37CDE /* ElementwiseAddPreluKernel.inc.metal */, FCE3A1B02153E90F00C37CDE /* ElementwiseAddPreluKernel.inc.metal */,
FCE3A1B22153E91900C37CDE /* ElementwiseAddPreluKernel.metal */, FCE3A1B22153E91900C37CDE /* ElementwiseAddPreluKernel.metal */,
FC9C2A0C21D3D185005856C6 /* FetchKernel.inc.metal */, FC9C2A0C21D3D185005856C6 /* FetchKernel.inc.metal */,
FC2BFD5021DF8E0400C262B2 /* Scale.metal */,
); );
path = metal; path = metal;
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -558,9 +542,7 @@ ...@@ -558,9 +542,7 @@
isa = PBXHeadersBuildPhase; isa = PBXHeadersBuildPhase;
buildActionMask = 2147483647; buildActionMask = 2147483647;
files = ( files = (
FC292C5421421B2F00CF622F /* PaddleMobileGPU.h in Headers */,
FC039B6F20E11C3C0081E9F8 /* paddle_mobile.h in Headers */, FC039B6F20E11C3C0081E9F8 /* paddle_mobile.h in Headers */,
FC9797CD21D61B2E00F2FD90 /* CPUCompute.h in Headers */,
); );
runOnlyForDeploymentPostprocessing = 0; runOnlyForDeploymentPostprocessing = 0;
}; };
...@@ -678,7 +660,6 @@ ...@@ -678,7 +660,6 @@
FCE3A1AB2153DE8C00C37CDE /* ConvAddAddPreluKernel.swift in Sources */, FCE3A1AB2153DE8C00C37CDE /* ConvAddAddPreluKernel.swift in Sources */,
FC9D037920E229E4000F735A /* OpParam.swift in Sources */, FC9D037920E229E4000F735A /* OpParam.swift in Sources */,
FC3602CC2108819F00FACB58 /* PaddleMobileUnitTest.swift in Sources */, FC3602CC2108819F00FACB58 /* PaddleMobileUnitTest.swift in Sources */,
FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */,
FCDDC6CC212FDFDB00E5EF74 /* ReluKernel.metal in Sources */, FCDDC6CC212FDFDB00E5EF74 /* ReluKernel.metal in Sources */,
FC0226562138F33800F395E2 /* TransposeKernel.metal in Sources */, FC0226562138F33800F395E2 /* TransposeKernel.metal in Sources */,
FCDDC6C6212F9FB800E5EF74 /* PreluKernel.swift in Sources */, FCDDC6C6212F9FB800E5EF74 /* PreluKernel.swift in Sources */,
...@@ -686,13 +667,10 @@ ...@@ -686,13 +667,10 @@
FCA67CD52138272900BD58AA /* ConvAddMetal.metal in Sources */, FCA67CD52138272900BD58AA /* ConvAddMetal.metal in Sources */,
FCBCCC5B2122F66F00D94F7E /* ConvBNReluKernel.swift in Sources */, FCBCCC5B2122F66F00D94F7E /* ConvBNReluKernel.swift in Sources */,
4AA1EA8C2146640900D0F791 /* SplitOp.swift in Sources */, 4AA1EA8C2146640900D0F791 /* SplitOp.swift in Sources */,
FC292C81214255BD00CF622F /* CPUCompute.mm in Sources */,
FCEBC0F420F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift in Sources */,
4AA1EAAC214F55C800D0F791 /* Softmax.inc.metal in Sources */, 4AA1EAAC214F55C800D0F791 /* Softmax.inc.metal in Sources */,
FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */, FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */,
4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */, 4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */,
FC803BBF214CB65A0094B8E5 /* ConvAddPreluOp.swift in Sources */, FC803BBF214CB65A0094B8E5 /* ConvAddPreluOp.swift in Sources */,
FC33B0F02147659000714A93 /* MobileNet.swift in Sources */,
FCEB684C212F093800D2448E /* PreluOp.swift in Sources */, FCEB684C212F093800D2448E /* PreluOp.swift in Sources */,
4AA1EAA8214B7AFB00D0F791 /* BilinearInterp.inc.metal in Sources */, 4AA1EAA8214B7AFB00D0F791 /* BilinearInterp.inc.metal in Sources */,
FCA67CD92138287B00BD58AA /* ConvBNReluKernel.metal in Sources */, FCA67CD92138287B00BD58AA /* ConvBNReluKernel.metal in Sources */,
...@@ -704,17 +682,19 @@ ...@@ -704,17 +682,19 @@
FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */, FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */,
FC039BA020E11CB20081E9F8 /* Dim.swift in Sources */, FC039BA020E11CB20081E9F8 /* Dim.swift in Sources */,
FC039BB820E11CC20081E9F8 /* framework.pb.swift in Sources */, FC039BB820E11CC20081E9F8 /* framework.pb.swift in Sources */,
C28FDF8521B7858F0054EFAC /* YoloNet.swift in Sources */,
FC039B9920E11C9A0081E9F8 /* Types.swift in Sources */, FC039B9920E11C9A0081E9F8 /* Types.swift in Sources */,
FC4CB74920F0B954007C0C6D /* ConvKernel.metal in Sources */, FC4CB74920F0B954007C0C6D /* ConvKernel.metal in Sources */,
FCA3A1632132A4AC00084FE5 /* ReshapeKernel.metal in Sources */, FCA3A1632132A4AC00084FE5 /* ReshapeKernel.metal in Sources */,
FCBCCC592122F42700D94F7E /* ConvBNReluOp.swift in Sources */, FCBCCC592122F42700D94F7E /* ConvBNReluOp.swift in Sources */,
FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */, FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */,
FCCED5E121D71FC000BE8D5F /* PoolKernel.inc.metal in Sources */, FCCED5E121D71FC000BE8D5F /* PoolKernel.inc.metal in Sources */,
FC2BFD4A21DF81DE00C262B2 /* Kernel.swift in Sources */,
FC9D038420E23B01000F735A /* Texture.swift in Sources */, FC9D038420E23B01000F735A /* Texture.swift in Sources */,
FCE3A1B32153E91900C37CDE /* ElementwiseAddPreluKernel.metal in Sources */, FCE3A1B32153E91900C37CDE /* ElementwiseAddPreluKernel.metal in Sources */,
FC2BFD4E21DF820B00C262B2 /* ConvAddBatchNormReluOp.swift in Sources */,
4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */, 4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */,
4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */, 4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */,
FC2BFCC221DF2F9100C262B2 /* GlobalConfig.swift in Sources */,
FCBCCC652122FCD700D94F7E /* TransposeOp.swift in Sources */, FCBCCC652122FCD700D94F7E /* TransposeOp.swift in Sources */,
4AA1EAA6214B5F6800D0F791 /* Shape.metal in Sources */, 4AA1EAA6214B5F6800D0F791 /* Shape.metal in Sources */,
FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */, FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */,
...@@ -724,13 +704,11 @@ ...@@ -724,13 +704,11 @@
FCD04E7420F3437E0007374F /* ConvAddKernel.swift in Sources */, FCD04E7420F3437E0007374F /* ConvAddKernel.swift in Sources */,
FC1CF3F721D4B4C400F7392E /* Runner.swift in Sources */, FC1CF3F721D4B4C400F7392E /* Runner.swift in Sources */,
FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */, FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */,
FC292C5621421B4600CF622F /* PaddleMobileGPU.m in Sources */,
FCD04E6620F314C50007374F /* PoolOp.swift in Sources */, FCD04E6620F314C50007374F /* PoolOp.swift in Sources */,
FCE9D7B9214FAA4800B520C3 /* NMSFetchResultKernel.metal in Sources */, FCE9D7B9214FAA4800B520C3 /* NMSFetchResultKernel.metal in Sources */,
FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */, FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */,
FCBCCC6F2123097100D94F7E /* MulticlassNMSOp.swift in Sources */, FCBCCC6F2123097100D94F7E /* MulticlassNMSOp.swift in Sources */,
FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */, FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */,
FC292C872142624800CF622F /* Genet.swift in Sources */,
FC803BC5214CB8F00094B8E5 /* ConvAddPrelu.inc.metal in Sources */, FC803BC5214CB8F00094B8E5 /* ConvAddPrelu.inc.metal in Sources */,
4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */, 4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */,
FCBCCC632122FCC000D94F7E /* TransposeKernel.swift in Sources */, FCBCCC632122FCC000D94F7E /* TransposeKernel.swift in Sources */,
...@@ -748,23 +726,20 @@ ...@@ -748,23 +726,20 @@
FCBCCC5D2122F8A100D94F7E /* DepthwiseConvOp.swift in Sources */, FCBCCC5D2122F8A100D94F7E /* DepthwiseConvOp.swift in Sources */,
FCE3A1AF2153E8EE00C37CDE /* ElementwiseAddPreluKernel.swift in Sources */, FCE3A1AF2153E8EE00C37CDE /* ElementwiseAddPreluKernel.swift in Sources */,
FCE9D7B7214F869000B520C3 /* Net.swift in Sources */, FCE9D7B7214F869000B520C3 /* Net.swift in Sources */,
FC704C2721D2385100F98BAB /* SuperResolutionNet.swift in Sources */,
FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */, FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */,
FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */, FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */,
C28FDF8421B7858F0054EFAC /* MobileNetCombined.swift in Sources */,
FCD04E6A20F319EC0007374F /* SoftmaxOp.swift in Sources */, FCD04E6A20F319EC0007374F /* SoftmaxOp.swift in Sources */,
FC292C82214255BD00CF622F /* MobileNetSSD.swift in Sources */,
FCBCCC612122FBDF00D94F7E /* PriorBoxKernel.swift in Sources */, FCBCCC612122FBDF00D94F7E /* PriorBoxKernel.swift in Sources */,
FCBCCC5F2122FB3B00D94F7E /* PriorBoxOp.swift in Sources */, FCBCCC5F2122FB3B00D94F7E /* PriorBoxOp.swift in Sources */,
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */, FC9D038220E2312E000F735A /* FetchOp.swift in Sources */,
FCA67B1721364EF000BD58AA /* ConvTransposeKernel.metal in Sources */, FCA67B1721364EF000BD58AA /* ConvTransposeKernel.metal in Sources */,
FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */, FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */,
FC2BFD5121DF8E0400C262B2 /* Scale.metal in Sources */,
FC039BA220E11CB70081E9F8 /* Loader.swift in Sources */, FC039BA220E11CB70081E9F8 /* Loader.swift in Sources */,
FCBCCC67212306B000D94F7E /* ConcatOp.swift in Sources */, FCBCCC67212306B000D94F7E /* ConcatOp.swift in Sources */,
FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */, FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */,
FCEB684A212F00DB00D2448E /* PreluKernel.metal in Sources */, FCEB684A212F00DB00D2448E /* PreluKernel.metal in Sources */,
4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.inc.metal in Sources */, 4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.inc.metal in Sources */,
FC9A19E32148C31300CD9CBF /* MobilenetSSD_AR.swift in Sources */,
FCDDC6CF212FE14700E5EF74 /* PriorBoxKernel.metal in Sources */, FCDDC6CF212FE14700E5EF74 /* PriorBoxKernel.metal in Sources */,
FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */, FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */,
FCE3A1A92153DE5100C37CDE /* ConvAddAddPreluOp.swift in Sources */, FCE3A1A92153DE5100C37CDE /* ConvAddAddPreluOp.swift in Sources */,
...@@ -780,6 +755,7 @@ ...@@ -780,6 +755,7 @@
FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */, FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */,
FC9797C921D6101D00F2FD90 /* ResizeBilinearOp.swift in Sources */, FC9797C921D6101D00F2FD90 /* ResizeBilinearOp.swift in Sources */,
4AA1EA88214662BD00D0F791 /* BilinearInterpKernel.swift in Sources */, 4AA1EA88214662BD00D0F791 /* BilinearInterpKernel.swift in Sources */,
FC2BFD4621DF685F00C262B2 /* Scale.swift in Sources */,
FC039B9720E11C9A0081E9F8 /* Extensions.swift in Sources */, FC039B9720E11C9A0081E9F8 /* Extensions.swift in Sources */,
); );
runOnlyForDeploymentPostprocessing = 0; runOnlyForDeploymentPostprocessing = 0;
...@@ -919,7 +895,7 @@ ...@@ -919,7 +895,7 @@
DYLIB_COMPATIBILITY_VERSION = 1; DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1; DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath"; DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_BITCODE = NO; ENABLE_BITCODE = YES;
INFOPLIST_FILE = "paddle-mobile/Info.plist"; INFOPLIST_FILE = "paddle-mobile/Info.plist";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks"; INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
IPHONEOS_DEPLOYMENT_TARGET = 9.0; IPHONEOS_DEPLOYMENT_TARGET = 9.0;
...@@ -956,7 +932,7 @@ ...@@ -956,7 +932,7 @@
DYLIB_COMPATIBILITY_VERSION = 1; DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1; DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath"; DYLIB_INSTALL_NAME_BASE = "@rpath";
ENABLE_BITCODE = NO; ENABLE_BITCODE = YES;
INFOPLIST_FILE = "paddle-mobile/Info.plist"; INFOPLIST_FILE = "paddle-mobile/Info.plist";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks"; INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
IPHONEOS_DEPLOYMENT_TARGET = 9.0; IPHONEOS_DEPLOYMENT_TARGET = 9.0;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
@objc public enum MetalLoadMode: Int {
case
LoadMetalInPaddleMobile = 1, // 使用 paddle-mobile 中的 metal 代码
LoadMetalInDefaultLib = 2, // 使用 main bundle 中的 metal 代码
LoadMetalInCustomMetalLib = 3 // 使用 metal 库文件
}
@objc public enum ComputePrecision: Int {
case
Float32 = 1,
Float16 = 2
}
@objc public class GlobalConfig: NSObject {
/// 单例
@objc public static let shared: GlobalConfig = GlobalConfig.init()
/// 运算精度, runner 生命周期中不可变
@objc public var computePrecision: ComputePrecision = .Float16
}
...@@ -12,58 +12,52 @@ ...@@ -12,58 +12,52 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
import Foundation
import Metal import Metal
import Foundation
public class ResultHolder: NSObject { /// 网络的基类, 参数已经给了默认值,请在子类实现中修改需要改的参数
@objc public let result: UnsafeMutablePointer<Float32> @objc open class Net: NSObject {
@objc public let capacity: Int
init(inResult: UnsafeMutablePointer<Float32>, inCapacity: Int) { /// 默认为0, 如果指定个数, 后边 except 个op不使用 GPU 运算, 中间结果会通过 fetchResult 传参过来
result = inResult @objc public var except: Int = 0
capacity = inCapacity
} /// 预处理 kernel, 如果输入图像需要预处理, 则指定预处理 kernel
@objc public var preprocessKernel: CusomKernel? = nil
@objc public func releasePointer() { // 以下四个参数为从内存中读取模型时用到的参数
result.deinitialize(count: capacity) /// 模型在内存中的指针
result.deallocate() @objc public var modelPointer: UnsafeMutableRawPointer? = nil
}
}
public class Net: NSObject {
var except: Int = 0
// for CPU /// 模型大小 单位: 字节
var means: [Float] = [] @objc public var modelSize: Int = 0
var scale: Float = 0.0
var needUpdateProgram = true /// 权重参数在内存中的指针
@objc public var paramPointer: UnsafeMutableRawPointer? = nil
public var inputDim: Dim { /// 权重大小 单位: 字节
get{ @objc public var paramSize: Int = 0
return inputDim_
}
set{
if inputDim_ != newValue {
needUpdateProgram = true
}
inputDim_ = newValue
}
}
var inputDim_: Dim = Dim.init(inDim: []) // 以下两个为从文件中读取模型时用到的参数
var preprocessKernel: CusomKernel? = nil /// 模型文件路径
var paramPointer: UnsafeMutableRawPointer? = nil @objc public var modelPath: String? = nil
var paramSize: Int = 0
var modelPointer: UnsafeMutableRawPointer? = nil /// 权重文件路径
var modelSize: Int = 0 @objc public var paramPath: String? = nil
var modelPath: String = ""
var paramPath: String = "" /// 代表着 GPU 处理器
var modelDir: String = "" @objc public let device: MTLDevice
let device: MTLDevice
/// metal 代码加载方式 注意: 如果静态库只能使用 LoadMetalInDefaultLib LoadMetalInCustomMetalLib 进行 load metal 代码
@objc public init(device: MTLDevice,paramPointer: UnsafeMutableRawPointer, paramSize:Int, modePointer: UnsafeMutableRawPointer, modelSize: Int) { @objc public var metalLoadMode: MetalLoadMode = .LoadMetalInPaddleMobile
/// 当 metalLoadMode 为 LoadMetalInCustomMetalLib 时, metal library 路径不能为空
@objc public var metalLibPath: String? = nil
/// 输入维度,按照 n h w c 方式传入
@objc public var inputDim: Dim = Dim.init(inDim: [])
@objc public init(device: MTLDevice, paramPointer: UnsafeMutableRawPointer, paramSize:Int, modePointer: UnsafeMutableRawPointer, modelSize: Int) {
self.paramPointer = paramPointer self.paramPointer = paramPointer
self.paramSize = paramSize self.paramSize = paramSize
self.modelPointer = modePointer self.modelPointer = modePointer
...@@ -76,23 +70,18 @@ public class Net: NSObject { ...@@ -76,23 +70,18 @@ public class Net: NSObject {
self.device = device self.device = device
super.init() super.init()
} }
@objc public func updateInputDim(inDim: [Int]) {
inputDim = Dim.init(inDim: inDim)
}
public func resultStr(res: ResultHolder) -> String { @objc open func resultStr(res: ResultHolder) -> String {
fatalError() fatalError()
} }
func fetchResult(paddleMobileRes: GPUResultHolder) -> ResultHolder { @objc open func fetchResult(paddleMobileRes: GPUResultHolder) -> ResultHolder {
guard let inResPointer = paddleMobileRes.resultPointer else { guard let inResPointer = paddleMobileRes.resultPointer else {
fatalError() fatalError()
} }
return ResultHolder.init(inResult: inResPointer, inCapacity: paddleMobileRes.capacity) return ResultHolder.init(inResult: inResPointer, inCapacity: paddleMobileRes.capacity)
} }
func updateProgram(program: Program) { open func updateProgram(program: Program) {
} }
} }
// /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Runner.swift
// paddle-mobile Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// Created by liuRuiLong on 2018/12/27. You may obtain a copy of the License at
// Copyright © 2018 orange. All rights reserved.
// http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import MetalKit import MetalKit
import Foundation import Foundation
class ScaleKernel: CusomKernel { @objc public class ResultHolder: NSObject {
init(device: MTLDevice, shape: Shape) { @objc public let result: UnsafeMutablePointer<Float32>
if computePrecision == .Float32 { @objc public let capacity: Int
super.init(device: device, inFunctionName: "scale", outputDim: shape, usePaddleMobileLib: false)
} else if computePrecision == .Float16 { init(inResult: UnsafeMutablePointer<Float32>, inCapacity: Int) {
super.init(device: device, inFunctionName: "scale_half", outputDim: shape, usePaddleMobileLib: false) result = inResult
} else { capacity = inCapacity
fatalError(" unsupport ") }
}
@objc public func releasePointer() {
result.deinitialize(count: capacity)
result.deallocate()
} }
} }
...@@ -29,15 +38,12 @@ class ScaleKernel: CusomKernel { ...@@ -29,15 +38,12 @@ class ScaleKernel: CusomKernel {
public let net: Net public let net: Net
let device: MTLDevice? let device: MTLDevice?
let numel: Int let numel: Int
let meansNumber: [NSNumber]
// dims num nchw /// 初始化函数
let dimsNum: [NSNumber] ///
/** /// - Parameters:
* inNet: 需要运行的网络 /// - inNet: 传入自定义的网络
* commandQueue: GPU 是需要传入 /// - commandQueue: commandQueue
* inPlatform: 需要使用的平台, GPU or CPU
*/
@objc public init(inNet: Net, commandQueue: MTLCommandQueue?) { @objc public init(inNet: Net, commandQueue: MTLCommandQueue?) {
guard inNet.inputDim.cout() == 4 else { guard inNet.inputDim.cout() == 4 else {
fatalError(" input dim count must 4 ") fatalError(" input dim count must 4 ")
...@@ -49,21 +55,12 @@ class ScaleKernel: CusomKernel { ...@@ -49,21 +55,12 @@ class ScaleKernel: CusomKernel {
if let inDevice = device { if let inDevice = device {
textureLoader = MTKTextureLoader.init(device: inDevice) textureLoader = MTKTextureLoader.init(device: inDevice)
} }
numel = net.inputDim.numel() numel = net.inputDim.numel()
meansNumber = net.means.map {
NSNumber.init(value: $0)
}
dimsNum = [NSNumber.init(value: net.inputDim[0]),
NSNumber.init(value: net.inputDim[3]),
NSNumber.init(value: net.inputDim[1]),
NSNumber.init(value: net.inputDim[2])]
} }
/** /// load 模型, 返回 true 可进行预测
* load 模型, 返回 true 可进行预测 ///
*/ /// - Returns: load 成功或失败
@objc public func load() -> Bool { @objc public func load() -> Bool {
guard let inDevice = device, let inQueue = queue else { guard let inDevice = device, let inQueue = queue else {
print(" paddle mobile gpu load error, need MTLCommandQueue") print(" paddle mobile gpu load error, need MTLCommandQueue")
...@@ -71,10 +68,24 @@ class ScaleKernel: CusomKernel { ...@@ -71,10 +68,24 @@ class ScaleKernel: CusomKernel {
} }
let loader = Loader<Float32>.init() let loader = Loader<Float32>.init()
do { do {
// program = try loader.load(device: inDevice, paramPointer: net.paramPointer!, paramSize: net.paramSize,modePointer:net.modelPointer!,modelSize:net.modelSize)
program = try loader.load(device: inDevice, modelPath: net.modelPath, paraPath: net.paramPath)
executor = try Executor<Float32>.init(inDevice: inDevice, inQueue: inQueue, inProgram: program!) if let inParamPointer = net.paramPointer, let inModelPointer = net.modelPointer {
guard net.paramSize > 0 && net.modelSize > 0 else {
print(" load from memory param size or model size can't 0 ")
return false
}
program = try loader.load(device: inDevice, paramPointer: inParamPointer, paramSize: net.paramSize,modePointer:inModelPointer,modelSize:net.modelSize)
} else if let inModelPath = net.modelPath, let inParamPath = net.paramPath {
program = try loader.load(device: inDevice, modelPath: inModelPath, paraPath: inParamPath)
} else {
print(" model pointer or model file path need be specified")
return false
}
let initContext: InitContext = InitContext.init()
initContext.metalLoadMode = net.metalLoadMode
initContext.metalLibPath = net.metalLibPath
executor = try Executor<Float32>.init(inDevice: inDevice, inQueue: inQueue, inProgram: program!, initContext: initContext)
net.updateProgram(program: program!) net.updateProgram(program: program!)
} catch let error { } catch let error {
print(error) print(error)
...@@ -83,13 +94,12 @@ class ScaleKernel: CusomKernel { ...@@ -83,13 +94,12 @@ class ScaleKernel: CusomKernel {
return true return true
} }
/** /// 预测
* GPU 版本 predict ///
* texture: 需要预测的 texture 需要做过预处理 /// - Parameters:
* ( _ success: Bool, _ time:TimeInterval, _ resultArray: [Float32]) -> Void : 回调闭包, 三个参数分别为: 是否成功, 预测耗时, 结果数组 /// - texture: 输入 texture 需要使用 getTexture 获得
*/ /// - completion: 结果回调, 当 success 为 true 时 result 不为 nil
@objc public func predict(texture: MTLTexture, completion: @escaping ( _ success: Bool, _ result: ResultHolder?) -> Void) { @objc public func predict(texture: MTLTexture, completion: @escaping ( _ success: Bool, _ result: ResultHolder?) -> Void) {
net.updateProgram(program: program!)
do { do {
try self.executor?.predict(input: texture, dim: self.net.inputDim, completionHandle: { [weak self] (res) in try self.executor?.predict(input: texture, dim: self.net.inputDim, completionHandle: { [weak self] (res) in
guard let SSelf = self else { guard let SSelf = self else {
...@@ -105,23 +115,28 @@ class ScaleKernel: CusomKernel { ...@@ -105,23 +115,28 @@ class ScaleKernel: CusomKernel {
} }
} }
/* /// 清理内存, 调用此函数后, 不能再使用, 需重新 load
* 清理内存, 调用此函数后, 不能再使用, 需重新 load
*/
@objc public func clear() { @objc public func clear() {
executor?.clear() executor?.clear()
executor = nil executor = nil
program = nil program = nil
} }
/* /// 获取 texture, 对 texture 进行预处理, 预测时使用
* 获取 texture, 对 texture 进行预处理, GPU 预测时使用 ///
*/ /// - Parameters:
/// - image: 输入图像
/// - getTexture: 获取 texture 回调
@objc public func getTexture(image: CGImage, getTexture: @escaping (MTLTexture) -> Void) { @objc public func getTexture(image: CGImage, getTexture: @escaping (MTLTexture) -> Void) {
let texture = try? textureLoader?.newTexture(cgImage: image, options: [:]) ?! " texture loader error" let texture = try? textureLoader?.newTexture(cgImage: image, options: [:]) ?! " texture loader error"
scaleTexture(input: texture!, complete: getTexture) scaleTexture(input: texture!, complete: getTexture)
} }
/// 通过 buffer 获取 texture, 内部会使用GPU进行转换操作
///
/// - Parameters:
/// - inBuffer: 输入buffer
/// - getTexture: 结果回调
@objc public func getTexture(inBuffer: MTLBuffer, getTexture: @escaping (MTLTexture) -> Void) { @objc public func getTexture(inBuffer: MTLBuffer, getTexture: @escaping (MTLTexture) -> Void) {
guard let inQueue = queue, let inDevice = device else { guard let inQueue = queue, let inDevice = device else {
fatalError( " queue or devcie nil " ) fatalError( " queue or devcie nil " )
...@@ -131,7 +146,7 @@ class ScaleKernel: CusomKernel { ...@@ -131,7 +146,7 @@ class ScaleKernel: CusomKernel {
fatalError( " make buffer error" ) fatalError( " make buffer error" )
} }
let bufferToTextureKernel = BufferToTextureKernel.init(device: inDevice, outputDim: Shape.init(inWidth: net.inputDim[2], inHeight: net.inputDim[1], inChannel: net.inputDim[3])) let bufferToTextureKernel = BufferToTextureKernel.init(device: inDevice, outputDim: Shape.init(inWidth: net.inputDim[2], inHeight: net.inputDim[1], inChannel: net.inputDim[3]), metalLoadMode: net.metalLoadMode, metalLibPath: net.metalLibPath)
do { do {
try bufferToTextureKernel.compute(inputBuffer: inBuffer, commandBuffer: buffer) try bufferToTextureKernel.compute(inputBuffer: inBuffer, commandBuffer: buffer)
} catch { } catch {
...@@ -144,6 +159,19 @@ class ScaleKernel: CusomKernel { ...@@ -144,6 +159,19 @@ class ScaleKernel: CusomKernel {
buffer.commit() buffer.commit()
} }
/// 更新输入维度, 针对可变长输入模型
///
/// - Parameter inDim: 输入维度
@objc public func updateInputDim(inDim: Dim) {
if net.inputDim != inDim {
guard let inProgram = program else {
fatalError(" need load first ")
}
net.inputDim = inDim
net.updateProgram(program: inProgram)
}
}
public func scaleTexture(input: MTLTexture , complete: @escaping (MTLTexture) -> Void) { public func scaleTexture(input: MTLTexture , complete: @escaping (MTLTexture) -> Void) {
...@@ -155,7 +183,7 @@ class ScaleKernel: CusomKernel { ...@@ -155,7 +183,7 @@ class ScaleKernel: CusomKernel {
fatalError( " make buffer error" ) fatalError( " make buffer error" )
} }
let scaleKernel = ScaleKernel.init(device: inDevice, shape: Shape.init(inWidth: net.inputDim[2], inHeight: net.inputDim[1], inChannel: 3)) let scaleKernel = ScaleKernel.init(device: inDevice, shape: Shape.init(inWidth: net.inputDim[2], inHeight: net.inputDim[1], inChannel: 3), metalLoadMode: net.metalLoadMode, metalLibPath: net.metalLibPath)
do { do {
try scaleKernel.compute(inputTexuture: input, commandBuffer: buffer) try scaleKernel.compute(inputTexuture: input, commandBuffer: buffer)
......
//
// MobileNetCombined.swift
// paddle-mobile
//
// Created by Xiao,Haichun on 2018/12/5.
// Copyright © 2018 orange. All rights reserved.
//
import Foundation
public class MobileNetCombined: Net {
@objc public override init(device: MTLDevice) {
super.init(device: device)
means = [0, 0, 0]
scale = 1
except = 0
modelPath = Bundle.main.path(forResource: "combined_mobilenet_model", ofType: nil) ?! "model null"
paramPath = Bundle.main.path(forResource: "combined_mobilenet_params", ofType: nil) ?! "para null"
modelDir = ""
inputDim_ = Dim.init(inDim: [1, 224, 224, 3])
}
@objc override public init(device: MTLDevice,paramPointer: UnsafeMutableRawPointer, paramSize:Int, modePointer: UnsafeMutableRawPointer, modelSize: Int) {
super.init(device:device,paramPointer:paramPointer,paramSize:paramSize,modePointer:modePointer,modelSize:modelSize)
means = [0, 0, 0]
scale = 1
except = 0
modelPath = ""
paramPath = ""
modelDir = ""
inputDim_ = Dim.init(inDim: [1, 224, 224, 3])
}
// class GenetPreProccess: CusomKernel {
// init(device: MTLDevice) {
// let s = CusomKernel.Shape.init(inWidth: 128, inHeight: 128, inChannel: 3)
// super.init(device: device, inFunctionName: "genet_preprocess", outputDim: s, usePaddleMobileLib: false)
// }
// }
override public func resultStr(res: ResultHolder) -> String {
// fatalError()
return " \(res.result[0]) ... "
}
}
//
// YoloNet.swift
// paddle-mobile
//
// Created by Xiao,Haichun on 2018/12/5.
// Copyright © 2018 orange. All rights reserved.
//
import Foundation
import Metal
public class YoloNet: Net {
@objc public override init(device: MTLDevice) {
super.init(device: device)
means = [0, 0, 0]
scale = 1
except = 0
modelPath = Bundle.main.path(forResource: "yolo_model", ofType: nil) ?! "model null"
paramPath = Bundle.main.path(forResource: "yolo_params", ofType: nil) ?! "para null"
modelDir = ""
// preprocessKernel = GenetPreProccess.init(device: device)
inputDim_ = Dim.init(inDim: [1, 416, 416, 3])
}
@objc override public init(device: MTLDevice,paramPointer: UnsafeMutableRawPointer, paramSize:Int, modePointer: UnsafeMutableRawPointer, modelSize: Int) {
super.init(device:device,paramPointer:paramPointer,paramSize:paramSize,modePointer:modePointer,modelSize:modelSize)
means = [0, 0, 0]
scale = 1
except = 0
modelPath = ""
paramPath = ""
modelDir = ""
//preprocessKernel = GenetPreProccess.init(device: device)
inputDim_ = Dim.init(inDim: [1, 416, 416, 3])
}
// class GenetPreProccess: CusomKernel {
// init(device: MTLDevice) {
// let s = CusomKernel.Shape.init(inWidth: 128, inHeight: 128, inChannel: 3)
// super.init(device: device, inFunctionName: "genet_preprocess", outputDim: s, usePaddleMobileLib: false)
// }
// }
override public func resultStr(res: ResultHolder) -> String {
return " \(res.result[0]) ... "
}
}
...@@ -18,6 +18,7 @@ import CoreMedia ...@@ -18,6 +18,7 @@ import CoreMedia
fileprivate var defaultMetalLibrary: MTLLibrary? fileprivate var defaultMetalLibrary: MTLLibrary?
fileprivate var paddleMobileMetalLibrary: MTLLibrary? fileprivate var paddleMobileMetalLibrary: MTLLibrary?
fileprivate var customMetalLibrary: MTLLibrary?
extension MTLDevice { extension MTLDevice {
func defaultLibrary() -> MTLLibrary { func defaultLibrary() -> MTLLibrary {
...@@ -31,6 +32,22 @@ extension MTLDevice { ...@@ -31,6 +32,22 @@ extension MTLDevice {
} }
} }
func customLibrary(metalLibPath: String) -> MTLLibrary {
if customMetalLibrary == nil {
do {
customMetalLibrary = try makeLibrary(filepath: metalLibPath)
} catch let error {
fatalError("\(error)")
}
}
if let inMetalLib = customMetalLibrary {
return inMetalLib
} else {
fatalError(" customlib is nil ")
}
}
func paddleMobileLibrary() -> MTLLibrary { func paddleMobileLibrary() -> MTLLibrary {
if paddleMobileMetalLibrary == nil { if paddleMobileMetalLibrary == nil {
guard let path = Bundle.init(for: Kernel.self).path(forResource: "default", ofType: "metallib") else { guard let path = Bundle.init(for: Kernel.self).path(forResource: "default", ofType: "metallib") else {
...@@ -50,8 +67,19 @@ extension MTLDevice { ...@@ -50,8 +67,19 @@ extension MTLDevice {
} }
} }
func pipeLine(funcName: String, inPaddleMobileLib: Bool = true) -> MTLComputePipelineState { func pipeLine(funcName: String, metalLoadMode: MetalLoadMode, metalLibPath: String?) -> MTLComputePipelineState {
let useLib = inPaddleMobileLib ? paddleMobileLibrary() : defaultLibrary() let useLib: MTLLibrary
switch metalLoadMode {
case .LoadMetalInDefaultLib:
useLib = defaultLibrary()
case .LoadMetalInPaddleMobile:
useLib = paddleMobileLibrary()
case .LoadMetalInCustomMetalLib:
useLib = customLibrary(metalLibPath: metalLibPath ?! " can't be nil ")
default:
fatalError()
}
guard let function = useLib.makeFunction(name: funcName) else { guard let function = useLib.makeFunction(name: funcName) else {
fatalError(" function " + funcName + " not found") fatalError(" function " + funcName + " not found")
} }
......
...@@ -324,9 +324,10 @@ public class PaddleMobileUnitTest { ...@@ -324,9 +324,10 @@ public class PaddleMobileUnitTest {
let param = ConvAddBatchNormReluTestParam.init(inInputTexture: inputeTexture, inOutputTexture: outputTexture, inMetalParam: metalParam, inFilterBuffer: filterBuffer, inBiaseBuffer: biaseBuffer, inNewScaleBuffer: newScalueBuffer, inNewBiaseBuffer: newBiaseBuffer, inFilterSize: filterSize) let param = ConvAddBatchNormReluTestParam.init(inInputTexture: inputeTexture, inOutputTexture: outputTexture, inMetalParam: metalParam, inFilterBuffer: filterBuffer, inBiaseBuffer: biaseBuffer, inNewScaleBuffer: newScalueBuffer, inNewBiaseBuffer: newBiaseBuffer, inFilterSize: filterSize)
let initContext = InitContext.init()
initContext.metalLoadMode = .LoadMetalInDefaultLib
let convAddBnReluKernel = ConvAddBatchNormReluKernel<Float32>.init(device: device, testParam: param, initContext: initContext)
let convAddBnReluKernel = ConvAddBatchNormReluKernel<Float32>.init(device: device, testParam: param)
convAddBnReluKernel.test(commandBuffer: buffer, param: param) convAddBnReluKernel.test(commandBuffer: buffer, param: param)
......
...@@ -252,11 +252,11 @@ extension InputTexture: Variant { ...@@ -252,11 +252,11 @@ extension InputTexture: Variant {
extension MTLTexture where Self: Variant { extension MTLTexture where Self: Variant {
} }
class FetchHolder: Variant { public class FetchHolder: Variant {
var resultBuffer: MTLBuffer? var resultBuffer: MTLBuffer?
var dim: Dim public var dim: Dim
var capacity: Int public var capacity: Int
var paddedCapacity: Int public var paddedCapacity: Int
init(inPaddedCapacity: Int, inDim: Dim) { init(inPaddedCapacity: Int, inDim: Dim) {
paddedCapacity = inPaddedCapacity paddedCapacity = inPaddedCapacity
...@@ -264,7 +264,7 @@ class FetchHolder: Variant { ...@@ -264,7 +264,7 @@ class FetchHolder: Variant {
dim = inDim dim = inDim
} }
func initBuffer(device: MTLDevice) { public func initBuffer(device: MTLDevice) {
resultBuffer = device.makeBuffer(length: paddedCapacity * 4, options: []) resultBuffer = device.makeBuffer(length: paddedCapacity * 4, options: [])
} }
...@@ -278,12 +278,12 @@ class FetchHolder: Variant { ...@@ -278,12 +278,12 @@ class FetchHolder: Variant {
} }
extension FetchHolder: CustomStringConvertible, CustomDebugStringConvertible { extension FetchHolder: CustomStringConvertible, CustomDebugStringConvertible {
var description: String { public var description: String {
fatalError() fatalError()
// return "\(result)" // return "\(result)"
} }
var debugDescription: String { public var debugDescription: String {
fatalError() fatalError()
// return "\(result)" // return "\(result)"
} }
......
...@@ -27,19 +27,19 @@ class OpCreator<P: PrecisionType> { ...@@ -27,19 +27,19 @@ class OpCreator<P: PrecisionType> {
} }
} }
func creat(device: MTLDevice, opDesc: OpDesc, scope: Scope) throws -> Runable & InferShaperable { func creat(device: MTLDevice, opDesc: OpDesc, scope: Scope, initContext: InitContext) throws -> Runable & InferShaperable {
guard let opCreator = opCreators[opDesc.type] else { guard let opCreator = opCreators[opDesc.type] else {
throw PaddleMobileError.opError(message: "there is no " + opDesc.type + " yet") throw PaddleMobileError.opError(message: "there is no " + opDesc.type + " yet")
} }
do { do {
return try opCreator(device, opDesc, scope) return try opCreator(device, opDesc, scope, initContext)
} catch let error { } catch let error {
throw error throw error
} }
} }
let opCreators: [String : (MTLDevice, OpDesc, Scope) throws -> Runable & InferShaperable] = let opCreators: [String : (MTLDevice, OpDesc, Scope, InitContext) throws -> Runable & InferShaperable] =
[gConvType : ConvOp<P>.creat, [gConvType : ConvOp<P>.creat,
gBatchNormType : BatchNormOp<P>.creat, gBatchNormType : BatchNormOp<P>.creat,
gReluType : ReluOp<P>.creat, gReluType : ReluOp<P>.creat,
......
...@@ -31,7 +31,7 @@ protocol Runable { ...@@ -31,7 +31,7 @@ protocol Runable {
func run(device: MTLDevice, buffer: MTLCommandBuffer) throws func run(device: MTLDevice, buffer: MTLCommandBuffer) throws
func runImpl(device: MTLDevice,buffer: MTLCommandBuffer) throws func runImpl(device: MTLDevice,buffer: MTLCommandBuffer) throws
func delogOutput() func delogOutput()
func inputVariant() -> [String : [Variant]] func inputVariant() -> [String : [MTLBuffer]]
func computeMiddleResult(device: MTLDevice, buffer: MTLCommandBuffer) func computeMiddleResult(device: MTLDevice, buffer: MTLCommandBuffer)
} }
...@@ -44,7 +44,7 @@ extension Runable where Self: OperatorProtocol{ ...@@ -44,7 +44,7 @@ extension Runable where Self: OperatorProtocol{
} }
} }
func inputVariant() -> [String : [Variant]] { func inputVariant() -> [String : [MTLBuffer]] {
// return [:] // return [:]
fatalError(" op \(type) need implement inputVariant") fatalError(" op \(type) need implement inputVariant")
} }
...@@ -59,15 +59,26 @@ extension Runable where Self: OperatorProtocol{ ...@@ -59,15 +59,26 @@ extension Runable where Self: OperatorProtocol{
} }
} }
public class InitContext {
/// metal 代码加载方式
var metalLoadMode: MetalLoadMode = .LoadMetalInDefaultLib
/// 当 metalLoadMode 为 LoadMetalInCustomMetalLib 时, metal library 路径不能为空
var metalLibPath: String? = nil
init() {
metalLoadMode = .LoadMetalInDefaultLib
metalLibPath = nil
}
}
protocol Creator where Self: OperatorProtocol{ protocol Creator where Self: OperatorProtocol{
associatedtype OpType: OperatorProtocol & Runable & InferShaperable associatedtype OpType: OperatorProtocol & Runable & InferShaperable
static func creat(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws -> OpType static func creat(device: MTLDevice, opDesc: OpDesc, inScope: Scope, initContext: InitContext) throws -> OpType
} }
extension Creator where Self: OperatorProtocol { extension Creator where Self: OperatorProtocol {
static func creat(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws -> OpType { static func creat(device: MTLDevice, opDesc: OpDesc, inScope: Scope, initContext: InitContext) throws -> OpType {
do { do {
return try OpType.provide(device:device, opDesc: opDesc, inScope: inScope) return try OpType.provide(device:device, opDesc: opDesc, inScope: inScope, initContext: initContext)
} catch let error { } catch let error {
throw error throw error
} }
...@@ -89,13 +100,13 @@ protocol OperatorProtocol { ...@@ -89,13 +100,13 @@ protocol OperatorProtocol {
var attrs: [String : Attr] { get } var attrs: [String : Attr] { get }
var para: ParamType { get } var para: ParamType { get }
var kernel: KerType { get } var kernel: KerType { get }
init(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws init(device: MTLDevice, opDesc: OpDesc, inScope: Scope, initContext: InitContext) throws
} }
extension OperatorProtocol { extension OperatorProtocol {
static func provide(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws -> Self { static func provide(device: MTLDevice, opDesc: OpDesc, inScope: Scope, initContext: InitContext) throws -> Self {
do { do {
return try Self.init(device: device, opDesc: opDesc, inScope: inScope) return try Self.init(device: device, opDesc: opDesc, inScope: inScope, initContext: initContext)
} catch let error { } catch let error {
throw error throw error
} }
...@@ -103,18 +114,7 @@ extension OperatorProtocol { ...@@ -103,18 +114,7 @@ extension OperatorProtocol {
} }
class Operator <KernelType: Computable , ParameterType>: OperatorProtocol where KernelType.ParamType == ParameterType { class Operator <KernelType: Computable , ParameterType>: OperatorProtocol where KernelType.ParamType == ParameterType {
typealias ParamType = ParameterType required init(device: MTLDevice, opDesc: OpDesc, inScope: Scope, initContext: InitContext) throws {
typealias KerType = KernelType
let type: String
let inputs: [String : [String]]
var paraInputs: [String : [String]]
let outpus: [String : [String]]
let attrs: [String : Attr]
let para: ParamType
let scope: Scope
var kernel: KerType
required init(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws {
// print("create op: \(opDesc.type)")
type = opDesc.type type = opDesc.type
scope = inScope scope = inScope
inputs = opDesc.inputs inputs = opDesc.inputs
...@@ -126,8 +126,19 @@ class Operator <KernelType: Computable , ParameterType>: OperatorProtocol where ...@@ -126,8 +126,19 @@ class Operator <KernelType: Computable , ParameterType>: OperatorProtocol where
} catch let error { } catch let error {
throw error throw error
} }
kernel = KernelType.init(device: device, param: para) kernel = KernelType.init(device: device, param: para, initContext: initContext)
} }
typealias ParamType = ParameterType
typealias KerType = KernelType
let type: String
let inputs: [String : [String]]
var paraInputs: [String : [String]]
let outpus: [String : [String]]
let attrs: [String : Attr]
let para: ParamType
let scope: Scope
var kernel: KerType
} }
// op infos // op infos
......
...@@ -111,6 +111,7 @@ class ConvAddOp<P: PrecisionType>: Operator<ConvAddKernel<P>, ConvAddParam<P>>, ...@@ -111,6 +111,7 @@ class ConvAddOp<P: PrecisionType>: Operator<ConvAddKernel<P>, ConvAddParam<P>>,
// print(biase) // print(biase)
print(" \(type) output: ") print(" \(type) output: ")
print(para.output.metalTexture)
print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray()) print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray())
} }
} }
...@@ -17,14 +17,6 @@ import Foundation ...@@ -17,14 +17,6 @@ import Foundation
class DepthConvOp<P: PrecisionType>: Operator<ConvKernel<P>, ConvParam<P>>, Runable, Creator, InferShaperable { class DepthConvOp<P: PrecisionType>: Operator<ConvKernel<P>, ConvParam<P>>, Runable, Creator, InferShaperable {
typealias OpType = DepthConvOp<P> typealias OpType = DepthConvOp<P>
required init(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws {
do {
try super.init(device: device, opDesc: opDesc, inScope: inScope)
} catch let error {
throw error
}
}
func inferShape() { func inferShape() {
let inDims = para.input.dim let inDims = para.input.dim
......
...@@ -32,7 +32,7 @@ class ElementwiseAddParam<P: PrecisionType>: OpParam { ...@@ -32,7 +32,7 @@ class ElementwiseAddParam<P: PrecisionType>: OpParam {
let device = inputX.metalTexture!.device let device = inputX.metalTexture!.device
inputY = Texture.init(device: device, inDim: tensorY.dim) inputY = Texture.init(device: device, inDim: tensorY.dim)
let value: [P] = Array(UnsafeBufferPointer(start: tensorY.data.pointer, count: tensorY.dim.numel())) let value: [P] = Array(UnsafeBufferPointer(start: tensorY.data.pointer, count: tensorY.dim.numel()))
inputY.metalTexture = device.tensor2texture(value: value, dim: tensorY.dim.dims, transpose: [0, 1, 2, 3], inComputePrecision: computePrecision) inputY.metalTexture = device.tensor2texture(value: value, dim: tensorY.dim.dims, transpose: [0, 1, 2, 3], inComputePrecision: GlobalConfig.shared.computePrecision)
} }
// required init(device: MTLDevice, param: ElementwiseAddParam<P>) { // required init(device: MTLDevice, param: ElementwiseAddParam<P>) {
......
...@@ -34,7 +34,7 @@ class ElementwiseAddPreluParam<P: PrecisionType>: OpParam { ...@@ -34,7 +34,7 @@ class ElementwiseAddPreluParam<P: PrecisionType>: OpParam {
let device = inputX.metalTexture!.device let device = inputX.metalTexture!.device
inputY = Texture.init(device: device, inDim: tensorY.dim) inputY = Texture.init(device: device, inDim: tensorY.dim)
let value: [P] = Array(UnsafeBufferPointer(start: tensorY.data.pointer, count: tensorY.dim.numel())) let value: [P] = Array(UnsafeBufferPointer(start: tensorY.data.pointer, count: tensorY.dim.numel()))
inputY.metalTexture = device.tensor2texture(value: value, dim: tensorY.dim.dims, transpose: [0, 1, 2, 3], inComputePrecision: computePrecision) inputY.metalTexture = device.tensor2texture(value: value, dim: tensorY.dim.dims, transpose: [0, 1, 2, 3], inComputePrecision: GlobalConfig.shared.computePrecision)
} }
// required init(device: MTLDevice, param: ElementwiseAddParam<P>) { // required init(device: MTLDevice, param: ElementwiseAddParam<P>) {
......
...@@ -63,7 +63,7 @@ class FeedOp<P: PrecisionType>: Operator<Texture2DTo2DArrayKernel<P>, FeedParam< ...@@ -63,7 +63,7 @@ class FeedOp<P: PrecisionType>: Operator<Texture2DTo2DArrayKernel<P>, FeedParam<
func delogOutput() { func delogOutput() {
print(" \(type) output: ") print(" \(type) output: ")
print(para.output.metalTexture)
print(para.output.metalTexture.toTensor(dim: (n: para.output.padToFourDim[0], c: para.output.padToFourDim[3], h: para.output.padToFourDim[2], w: para.output.padToFourDim[1])).strideArray()) print(para.output.metalTexture.toTensor(dim: (n: para.output.padToFourDim[0], c: para.output.padToFourDim[3], h: para.output.padToFourDim[2], w: para.output.padToFourDim[1])).strideArray())
} }
} }
......
...@@ -45,17 +45,17 @@ class FetchKernel<P: PrecisionType>: Kernel, Computable { ...@@ -45,17 +45,17 @@ class FetchKernel<P: PrecisionType>: Kernel, Computable {
encoder.endEncoding() encoder.endEncoding()
} }
required init(device: MTLDevice, param: FetchParam<P>) { required init(device: MTLDevice, param: FetchParam<P>, initContext: InitContext) {
param.output.initBuffer(device: device) param.output.initBuffer(device: device)
if computePrecision == .Float16 { if GlobalConfig.shared.computePrecision == .Float16 {
if param.input.transpose == [0, 2, 3, 1] { if param.input.transpose == [0, 2, 3, 1] {
super.init(device: device, inFunctionName: "fetch_half") super.init(device: device, inFunctionName: "fetch_half", initContext: initContext)
} else { } else {
fatalError(" not support ") fatalError(" not support ")
} }
} else if computePrecision == .Float32 { } else if GlobalConfig.shared.computePrecision == .Float32 {
if param.input.transpose == [0, 2, 3, 1] { if param.input.transpose == [0, 2, 3, 1] {
super.init(device: device, inFunctionName: "fetch_float") super.init(device: device, inFunctionName: "fetch_float", initContext: initContext)
} else { } else {
fatalError(" not support ") fatalError(" not support ")
} }
......
...@@ -21,14 +21,14 @@ public protocol TestParam { ...@@ -21,14 +21,14 @@ public protocol TestParam {
public protocol Testable { public protocol Testable {
associatedtype TestParamType: TestParam associatedtype TestParamType: TestParam
func test(commandBuffer: MTLCommandBuffer, param: TestParamType) func test(commandBuffer: MTLCommandBuffer, param: TestParamType)
init(device: MTLDevice, testParam: TestParamType) init(device: MTLDevice, testParam: TestParamType, initContext: InitContext)
} }
protocol Computable { protocol Computable {
associatedtype ParamType: OpParam associatedtype ParamType: OpParam
func compute(commandBuffer: MTLCommandBuffer, param: ParamType) throws func compute(commandBuffer: MTLCommandBuffer, param: ParamType) throws
init(device: MTLDevice, param: ParamType) init(device: MTLDevice, param: ParamType, initContext: InitContext)
} }
protocol KernelProtocol { protocol KernelProtocol {
...@@ -37,20 +37,20 @@ protocol KernelProtocol { ...@@ -37,20 +37,20 @@ protocol KernelProtocol {
} }
open class Kernel { @objc open class Kernel: NSObject{
let pipline: MTLComputePipelineState let pipline: MTLComputePipelineState
let functionName: String let functionName: String
public init(device: MTLDevice, inFunctionName: String, usePaddleMobileLib: Bool = true) { public init(device: MTLDevice, inFunctionName: String, usePaddleMobileLib: Bool = false, initContext: InitContext) {
pipline = device.pipeLine(funcName: inFunctionName, inPaddleMobileLib: usePaddleMobileLib) pipline = device.pipeLine(funcName: inFunctionName, metalLoadMode: initContext.metalLoadMode, metalLibPath: initContext.metalLibPath)
functionName = inFunctionName functionName = inFunctionName
} }
} }
public struct Shape { @objc public class Shape: NSObject {
public let width: Int public let width: Int
public let height: Int public let height: Int
public let channel: Int public let channel: Int
public init(inWidth: Int, inHeight: Int, inChannel: Int){ @objc public init(inWidth: Int, inHeight: Int, inChannel: Int){
width = inWidth width = inWidth
height = inHeight height = inHeight
channel = inChannel channel = inChannel
...@@ -60,16 +60,16 @@ public struct Shape { ...@@ -60,16 +60,16 @@ public struct Shape {
open class BufferToTextureKernel: Kernel { open class BufferToTextureKernel: Kernel {
public let outputTexture: MTLTexture public let outputTexture: MTLTexture
public init(device: MTLDevice, outputDim: Shape, usePaddleMobileLib: Bool = false) { public init(device: MTLDevice, outputDim: Shape, metalLoadMode: MetalLoadMode, metalLibPath: String?) {
let textureDesc = MTLTextureDescriptor.init() let textureDesc = MTLTextureDescriptor.init()
textureDesc.textureType = .type2D textureDesc.textureType = .type2D
textureDesc.width = outputDim.width textureDesc.width = outputDim.width
textureDesc.height = outputDim.height textureDesc.height = outputDim.height
textureDesc.depth = (outputDim.channel + 3) / 4 textureDesc.depth = (outputDim.channel + 3) / 4
if computePrecision == .Float16 { if GlobalConfig.shared.computePrecision == .Float16 {
textureDesc.pixelFormat = .rgba16Float textureDesc.pixelFormat = .rgba16Float
} else if computePrecision == .Float32 { } else if GlobalConfig.shared.computePrecision == .Float32 {
textureDesc.pixelFormat = .rgba32Float textureDesc.pixelFormat = .rgba32Float
} else { } else {
fatalError() fatalError()
...@@ -78,10 +78,13 @@ open class BufferToTextureKernel: Kernel { ...@@ -78,10 +78,13 @@ open class BufferToTextureKernel: Kernel {
textureDesc.usage = [.shaderRead, .shaderWrite] textureDesc.usage = [.shaderRead, .shaderWrite]
textureDesc.storageMode = .shared textureDesc.storageMode = .shared
outputTexture = device.makeTexture(descriptor: textureDesc) ?! " make texture error " outputTexture = device.makeTexture(descriptor: textureDesc) ?! " make texture error "
if computePrecision == .Float32 { let initContext = InitContext.init()
super.init(device: device, inFunctionName: "buffer_to_texture_kernel", usePaddleMobileLib: usePaddleMobileLib) initContext.metalLibPath = metalLibPath
initContext.metalLoadMode = metalLoadMode
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "buffer_to_texture_kernel", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "buffer_to_texture_kernel_half", usePaddleMobileLib: usePaddleMobileLib) super.init(device: device, inFunctionName: "buffer_to_texture_kernel_half", initContext: initContext)
} }
} }
...@@ -98,19 +101,19 @@ open class BufferToTextureKernel: Kernel { ...@@ -98,19 +101,19 @@ open class BufferToTextureKernel: Kernel {
} }
open class CusomKernel: Kernel { @objc open class CusomKernel: Kernel {
public let outputTexture: MTLTexture public let outputTexture: MTLTexture
public init(device: MTLDevice, inFunctionName: String, outputDim: Shape, usePaddleMobileLib: Bool = false) { public init(device: MTLDevice, inFunctionName: String, outputDim: Shape, metalLoadModel: MetalLoadMode, metalLibPath: String?) {
let textureDesc = MTLTextureDescriptor.init() let textureDesc = MTLTextureDescriptor.init()
textureDesc.textureType = .type2D textureDesc.textureType = .type2D
textureDesc.width = outputDim.width textureDesc.width = outputDim.width
textureDesc.height = outputDim.height textureDesc.height = outputDim.height
textureDesc.depth = (outputDim.channel + 3) / 4 textureDesc.depth = (outputDim.channel + 3) / 4
if computePrecision == .Float16 { if GlobalConfig.shared.computePrecision == .Float16 {
textureDesc.pixelFormat = .rgba16Float textureDesc.pixelFormat = .rgba16Float
} else if computePrecision == .Float32 { } else if GlobalConfig.shared.computePrecision == .Float32 {
textureDesc.pixelFormat = .rgba32Float textureDesc.pixelFormat = .rgba32Float
} else { } else {
fatalError() fatalError()
...@@ -120,7 +123,10 @@ open class CusomKernel: Kernel { ...@@ -120,7 +123,10 @@ open class CusomKernel: Kernel {
textureDesc.storageMode = .shared textureDesc.storageMode = .shared
outputTexture = device.makeTexture(descriptor: textureDesc) ?! " make texture error " outputTexture = device.makeTexture(descriptor: textureDesc) ?! " make texture error "
super.init(device: device, inFunctionName: inFunctionName, usePaddleMobileLib: usePaddleMobileLib) let context = InitContext.init()
context.metalLoadMode = metalLoadModel
context.metalLibPath = metalLibPath
super.init(device: device, inFunctionName: inFunctionName, initContext: context)
} }
public func compute(inputTexuture: MTLTexture, commandBuffer: MTLCommandBuffer) throws { public func compute(inputTexuture: MTLTexture, commandBuffer: MTLCommandBuffer) throws {
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import Foundation import Foundation
class BatchNormKernel<P: PrecisionType>: Kernel, Computable { class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
required init(device: MTLDevice, param: BatchNormParam<P>) { required init(device: MTLDevice, param: BatchNormParam<P>, initContext: InitContext) {
let count = param.variance.dim.numel() let count = param.variance.dim.numel()
let varianceP = param.variance.data.pointer let varianceP = param.variance.data.pointer
let meanP = param.mean.data.pointer let meanP = param.mean.data.pointer
...@@ -27,13 +27,13 @@ class BatchNormKernel<P: PrecisionType>: Kernel, Computable { ...@@ -27,13 +27,13 @@ class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
scaleP[i] = invStd * scaleP[i] scaleP[i] = invStd * scaleP[i]
} }
param.bias.initBuffer(device: device, precision: computePrecision) param.bias.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.scale.initBuffer(device: device, precision: computePrecision) param.scale.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "batchnorm") super.init(device: device, inFunctionName: "batchnorm", initContext: initContext)
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "batchnorm_half") super.init(device: device, inFunctionName: "batchnorm_half", initContext: initContext)
} else { } else {
fatalError() fatalError()
} }
......
...@@ -41,12 +41,12 @@ class BilinearInterpKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -41,12 +41,12 @@ class BilinearInterpKernel<P: PrecisionType>: Kernel, Computable{
encoder.endEncoding() encoder.endEncoding()
} }
required init(device: MTLDevice, param: BilinearInterpParam<P>) { required init(device: MTLDevice, param: BilinearInterpParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "bilinear_interp_float") super.init(device: device, inFunctionName: "bilinear_interp_float", initContext: initContext)
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "bilinear_interp_half") super.init(device: device, inFunctionName: "bilinear_interp_half", initContext: initContext)
} else { } else {
fatalError() fatalError()
} }
......
...@@ -32,12 +32,12 @@ class BoxcoderKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -32,12 +32,12 @@ class BoxcoderKernel<P: PrecisionType>: Kernel, Computable{
encoder.endEncoding() encoder.endEncoding()
} }
required init(device: MTLDevice, param: BoxcoderParam<P>) { required init(device: MTLDevice, param: BoxcoderParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: [0, 3, 1, 2], computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: [0, 3, 1, 2], computePrecision: GlobalConfig.shared.computePrecision)
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "boxcoder_float") super.init(device: device, inFunctionName: "boxcoder_float", initContext: initContext)
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "boxcoder_half") super.init(device: device, inFunctionName: "boxcoder_half", initContext: initContext)
} else { } else {
fatalError() fatalError()
} }
......
...@@ -52,8 +52,8 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -52,8 +52,8 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{
encoder.endEncoding() encoder.endEncoding()
} }
required init(device: MTLDevice, param: ConcatParam<P>) { required init(device: MTLDevice, param: ConcatParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: param.transpose, computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: param.transpose, computePrecision: GlobalConfig.shared.computePrecision)
let orank = param.output.tensorDim.cout() let orank = param.output.tensorDim.cout()
let num = param.input.count let num = param.input.count
assert(num <= 6) assert(num <= 6)
...@@ -133,16 +133,16 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -133,16 +133,16 @@ class ConcatKernel<P: PrecisionType>: Kernel, Computable{
} }
} }
pm.vdim = (Int32(vdim[0]), Int32(vdim[1]), Int32(vdim[2]), Int32(vdim[3]), Int32(vdim[4]), Int32(vdim[5])) pm.vdim = (Int32(vdim[0]), Int32(vdim[1]), Int32(vdim[2]), Int32(vdim[3]), Int32(vdim[4]), Int32(vdim[5]))
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "concat_\(orank)_\(num)_\(v)_float") super.init(device: device, inFunctionName: "concat_\(orank)_\(num)_\(v)_float", initContext: initContext)
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "concat_\(orank)_\(num)_\(v)_half") super.init(device: device, inFunctionName: "concat_\(orank)_\(num)_\(v)_half", initContext: initContext)
} else { } else {
fatalError() fatalError()
} }
} }
required init(device: MTLDevice, testParam: ConcatTestParam) { required init(device: MTLDevice, testParam: ConcatTestParam, initContext: InitContext) {
super.init(device: device, inFunctionName: "concat") super.init(device: device, inFunctionName: "concat", initContext: initContext)
} }
} }
...@@ -16,99 +16,99 @@ import Foundation ...@@ -16,99 +16,99 @@ import Foundation
class ConvAddAddPreluKernel<P: PrecisionType>: Kernel, Computable { class ConvAddAddPreluKernel<P: PrecisionType>: Kernel, Computable {
var metalParam: MetalConvParam! var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvAddAddPreluParam<P>) { required init(device: MTLDevice, param: ConvAddAddPreluParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
param.filter.initBuffer(device: device, precision: computePrecision) param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.y.initBuffer(device: device, precision: computePrecision) param.y.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.alpha.initBuffer(device: device, precision: computePrecision) param.alpha.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
if computePrecision == .Float16 { if GlobalConfig.shared.computePrecision == .Float16 {
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "conv_add_1x1_prelu_channel_half") super.init(device: device, inFunctionName: "conv_add_1x1_prelu_channel_half", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "conv_add_1x1_prelu_element_half") super.init(device: device, inFunctionName: "conv_add_1x1_prelu_element_half", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_1x1_prelu_other_half") super.init(device: device, inFunctionName: "conv_add_1x1_prelu_other_half", initContext: initContext)
} }
} else if param.filter.channel == 1 { } else if param.filter.channel == 1 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_channel_half") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_channel_half", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_element_half") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_element_half", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_other_half") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_other_half", initContext: initContext)
} }
} else if param.filter.width == 3 && param.filter.height == 3 { } else if param.filter.width == 3 && param.filter.height == 3 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "conv_add_3x3_prelu_channel_half") super.init(device: device, inFunctionName: "conv_add_3x3_prelu_channel_half", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "conv_add_3x3_prelu_element_half") super.init(device: device, inFunctionName: "conv_add_3x3_prelu_element_half", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_3x3_prelu_other_half") super.init(device: device, inFunctionName: "conv_add_3x3_prelu_other_half", initContext: initContext)
} }
} else if param.filter.width == 1 && param.filter.height == 5 { } else if param.filter.width == 1 && param.filter.height == 5 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "conv_add_5x1_prelu_channel_half") super.init(device: device, inFunctionName: "conv_add_5x1_prelu_channel_half", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "conv_add_5x1_prelu_element_half") super.init(device: device, inFunctionName: "conv_add_5x1_prelu_element_half", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_5x1_prelu_other_half") super.init(device: device, inFunctionName: "conv_add_5x1_prelu_other_half", initContext: initContext)
} }
} else if param.filter.width == 5 && param.filter.height == 1 { } else if param.filter.width == 5 && param.filter.height == 1 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "conv_add_1x5_prelu_channel_half") super.init(device: device, inFunctionName: "conv_add_1x5_prelu_channel_half", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "conv_add_1x5_prelu_element_half") super.init(device: device, inFunctionName: "conv_add_1x5_prelu_element_half", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_1x5_prelu_other_half") super.init(device: device, inFunctionName: "conv_add_1x5_prelu_other_half", initContext: initContext)
} }
} else { } else {
fatalError(" unsupport yet ") fatalError(" unsupport yet ")
} }
} else if computePrecision == .Float32 { } else if GlobalConfig.shared.computePrecision == .Float32 {
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "conv_add_1x1_prelu_channel_float") super.init(device: device, inFunctionName: "conv_add_1x1_prelu_channel_float", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "conv_add_1x1_prelu_element_float") super.init(device: device, inFunctionName: "conv_add_1x1_prelu_element_float", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_1x1_prelu_other_float") super.init(device: device, inFunctionName: "conv_add_1x1_prelu_other_float", initContext: initContext)
} }
} else if param.filter.channel == 1 { } else if param.filter.channel == 1 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_channel_float") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_channel_float", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_element_float") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_element_float", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_other_float") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_other_float", initContext: initContext)
} }
} else if param.filter.width == 3 && param.filter.height == 3 { } else if param.filter.width == 3 && param.filter.height == 3 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "conv_add_3x3_prelu_channel_float") super.init(device: device, inFunctionName: "conv_add_3x3_prelu_channel_float", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "conv_add_3x3_prelu_element_float") super.init(device: device, inFunctionName: "conv_add_3x3_prelu_element_float", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_3x3_prelu_other_float") super.init(device: device, inFunctionName: "conv_add_3x3_prelu_other_float", initContext: initContext)
} }
} else if param.filter.width == 1 && param.filter.height == 5 { } else if param.filter.width == 1 && param.filter.height == 5 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "conv_add_5x1_prelu_channel_float") super.init(device: device, inFunctionName: "conv_add_5x1_prelu_channel_float", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "conv_add_5x1_prelu_element_float") super.init(device: device, inFunctionName: "conv_add_5x1_prelu_element_float", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_5x1_prelu_other_float") super.init(device: device, inFunctionName: "conv_add_5x1_prelu_other_float", initContext: initContext)
} }
} else if param.filter.width == 5 && param.filter.height == 1 { } else if param.filter.width == 5 && param.filter.height == 1 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "conv_add_1x5_prelu_channel_float") super.init(device: device, inFunctionName: "conv_add_1x5_prelu_channel_float", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "conv_add_1x5_prelu_element_float") super.init(device: device, inFunctionName: "conv_add_1x5_prelu_element_float", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_1x5_prelu_other_float") super.init(device: device, inFunctionName: "conv_add_1x5_prelu_other_float", initContext: initContext)
} }
} else { } else {
fatalError(" unsupport yet ") fatalError(" unsupport yet ")
......
...@@ -37,44 +37,44 @@ struct ConvAddBatchNormReluTestParam: TestParam { ...@@ -37,44 +37,44 @@ struct ConvAddBatchNormReluTestParam: TestParam {
} }
class ConvAddBatchNormReluKernel<P: PrecisionType>: Kernel, Computable, Testable { class ConvAddBatchNormReluKernel<P: PrecisionType>: Kernel, Computable, Testable {
required init(device: MTLDevice, testParam: ConvAddBatchNormReluTestParam) { required init(device: MTLDevice, testParam: ConvAddBatchNormReluTestParam, initContext: InitContext) {
if testParam.filterSize.width == 1 && testParam.filterSize.height == 1 { if testParam.filterSize.width == 1 && testParam.filterSize.height == 1 {
super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_1x1") super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_1x1", initContext: initContext)
} else if testParam.filterSize.channel == 1 { } else if testParam.filterSize.channel == 1 {
super.init(device: device, inFunctionName: "depthwise_conv_add_batch_norm_relu_3x3") super.init(device: device, inFunctionName: "depthwise_conv_add_batch_norm_relu_3x3", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_3x3") super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_3x3", initContext: initContext)
} }
} }
var metalParam: MetalConvParam! var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvAddBatchNormReluParam<P>) { required init(device: MTLDevice, param: ConvAddBatchNormReluParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
param.filter.initBuffer(device: device, precision: computePrecision) param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.y.initBuffer(device: device, precision: computePrecision) param.y.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.variance.initBuffer(device: device, precision: .Float32) param.variance.initBuffer(device: device, precision: .Float32)
param.mean.initBuffer(device: device, precision: .Float32) param.mean.initBuffer(device: device, precision: .Float32)
param.scale.initBuffer(device: device, precision: .Float32) param.scale.initBuffer(device: device, precision: .Float32)
param.bias.initBuffer(device: device, precision: .Float32) param.bias.initBuffer(device: device, precision: .Float32)
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_1x1") super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_1x1", initContext: initContext)
} else if param.filter.channel == 1 { } else if param.filter.channel == 1 {
super.init(device: device, inFunctionName: "depthwise_conv_add_batch_norm_relu_3x3") super.init(device: device, inFunctionName: "depthwise_conv_add_batch_norm_relu_3x3", initContext: initContext)
} else if param.filter.width == 3 && param.filter.height == 3 { } else if param.filter.width == 3 && param.filter.height == 3 {
super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_3x3") super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_3x3", initContext: initContext)
} else { } else {
fatalError(" unsupport ") fatalError(" unsupport ")
} }
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_1x1_half") super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_1x1_half", initContext: initContext)
} else if param.filter.channel == 1 { } else if param.filter.channel == 1 {
super.init(device: device, inFunctionName: "depthwise_conv_add_batch_norm_relu_3x3_half") super.init(device: device, inFunctionName: "depthwise_conv_add_batch_norm_relu_3x3_half", initContext: initContext)
} else if param.filter.width == 3 && param.filter.height == 3 { } else if param.filter.width == 3 && param.filter.height == 3 {
super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_3x3_half") super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_3x3_half", initContext: initContext)
} else { } else {
fatalError(" unsupport ") fatalError(" unsupport ")
} }
...@@ -120,10 +120,10 @@ class ConvAddBatchNormReluKernel<P: PrecisionType>: Kernel, Computable, Testable ...@@ -120,10 +120,10 @@ class ConvAddBatchNormReluKernel<P: PrecisionType>: Kernel, Computable, Testable
var newBiaseBuffer: MTLBuffer var newBiaseBuffer: MTLBuffer
var newScaleBuffer: MTLBuffer var newScaleBuffer: MTLBuffer
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
newBiaseBuffer = device.makeBuffer(bytes: newBiase, length: param.bias.buffer.length)! newBiaseBuffer = device.makeBuffer(bytes: newBiase, length: param.bias.buffer.length)!
newScaleBuffer = device.makeBuffer(bytes: newScale, length: param.scale.buffer.length)! newScaleBuffer = device.makeBuffer(bytes: newScale, length: param.scale.buffer.length)!
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
newBiaseBuffer = device.makeBuffer(length: param.bias.buffer.length / 2)! newBiaseBuffer = device.makeBuffer(length: param.bias.buffer.length / 2)!
newScaleBuffer = device.makeBuffer(length: param.bias.buffer.length / 2)! newScaleBuffer = device.makeBuffer(length: param.bias.buffer.length / 2)!
......
...@@ -16,37 +16,37 @@ import Foundation ...@@ -16,37 +16,37 @@ import Foundation
class ConvAddKernel<P: PrecisionType>: Kernel, Computable { class ConvAddKernel<P: PrecisionType>: Kernel, Computable {
var metalParam: MetalConvParam! var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvAddParam<P>) { required init(device: MTLDevice, param: ConvAddParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
let padWhenOneC = !(param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1]) let padWhenOneC = !(param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1])
param.filter.initBuffer(device: device, precision: computePrecision, padWhenOneC: padWhenOneC) param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision, padWhenOneC: padWhenOneC)
param.y.initBuffer(device: device, precision: computePrecision) param.y.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
if computePrecision == .Float16 { if GlobalConfig.shared.computePrecision == .Float16 {
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_add_1x1_half") super.init(device: device, inFunctionName: "conv_add_1x1_half", initContext: initContext)
} else if param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1] { } else if param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1] {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_half") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_half", initContext: initContext)
} else if param.filter.width == 3 && param.filter.height == 3 { } else if param.filter.width == 3 && param.filter.height == 3 {
super.init(device: device, inFunctionName: "conv_add_3x3_half") super.init(device: device, inFunctionName: "conv_add_3x3_half", initContext: initContext)
} else if param.filter.width == 1 && param.filter.height == 5 { } else if param.filter.width == 1 && param.filter.height == 5 {
super.init(device: device, inFunctionName: "conv_add_5x1_half") super.init(device: device, inFunctionName: "conv_add_5x1_half", initContext: initContext)
} else if param.filter.width == 5 && param.filter.height == 1 { } else if param.filter.width == 5 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_add_1x5_half") super.init(device: device, inFunctionName: "conv_add_1x5_half", initContext: initContext)
} else { } else {
fatalError(" unsupport yet ") fatalError(" unsupport yet ")
} }
} else if computePrecision == .Float32 { } else if GlobalConfig.shared.computePrecision == .Float32 {
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_add_1x1") super.init(device: device, inFunctionName: "conv_add_1x1", initContext: initContext)
} else if param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1] { } else if param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1] {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3", initContext: initContext)
} else if param.filter.width == 1 && param.filter.height == 5 { } else if param.filter.width == 1 && param.filter.height == 5 {
super.init(device: device, inFunctionName: "conv_add_5x1") super.init(device: device, inFunctionName: "conv_add_5x1", initContext: initContext)
} else if param.filter.width == 5 && param.filter.height == 1 { } else if param.filter.width == 5 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_add_1x5") super.init(device: device, inFunctionName: "conv_add_1x5", initContext: initContext)
} else if param.filter.width == 3 && param.filter.height == 3 { } else if param.filter.width == 3 && param.filter.height == 3 {
super.init(device: device, inFunctionName: "conv_add_3x3") super.init(device: device, inFunctionName: "conv_add_3x3", initContext: initContext)
} else { } else {
fatalError(" unsupport yet ") fatalError(" unsupport yet ")
} }
......
...@@ -16,99 +16,99 @@ import Foundation ...@@ -16,99 +16,99 @@ import Foundation
class ConvAddPreluKernel<P: PrecisionType>: Kernel, Computable { class ConvAddPreluKernel<P: PrecisionType>: Kernel, Computable {
var metalParam: MetalConvParam! var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvAddPreluParam<P>) { required init(device: MTLDevice, param: ConvAddPreluParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
param.filter.initBuffer(device: device, precision: computePrecision) param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.y.initBuffer(device: device, precision: computePrecision) param.y.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.alpha.initBuffer(device: device, precision: computePrecision) param.alpha.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
if computePrecision == .Float16 { if GlobalConfig.shared.computePrecision == .Float16 {
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "conv_add_1x1_prelu_channel_half") super.init(device: device, inFunctionName: "conv_add_1x1_prelu_channel_half", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "conv_add_1x1_prelu_element_half") super.init(device: device, inFunctionName: "conv_add_1x1_prelu_element_half", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_1x1_prelu_other_half") super.init(device: device, inFunctionName: "conv_add_1x1_prelu_other_half", initContext: initContext)
} }
} else if param.filter.channel == 1 { } else if param.filter.channel == 1 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_channel_half") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_channel_half", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_element_half") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_element_half", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_other_half") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_other_half", initContext: initContext)
} }
} else if param.filter.width == 3 && param.filter.height == 3 { } else if param.filter.width == 3 && param.filter.height == 3 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "conv_add_3x3_prelu_channel_half") super.init(device: device, inFunctionName: "conv_add_3x3_prelu_channel_half", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "conv_add_3x3_prelu_element_half") super.init(device: device, inFunctionName: "conv_add_3x3_prelu_element_half", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_3x3_prelu_other_half") super.init(device: device, inFunctionName: "conv_add_3x3_prelu_other_half", initContext: initContext)
} }
} else if param.filter.width == 1 && param.filter.height == 5 { } else if param.filter.width == 1 && param.filter.height == 5 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "conv_add_5x1_prelu_channel_half") super.init(device: device, inFunctionName: "conv_add_5x1_prelu_channel_half", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "conv_add_5x1_prelu_element_half") super.init(device: device, inFunctionName: "conv_add_5x1_prelu_element_half", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_5x1_prelu_other_half") super.init(device: device, inFunctionName: "conv_add_5x1_prelu_other_half", initContext: initContext)
} }
} else if param.filter.width == 5 && param.filter.height == 1 { } else if param.filter.width == 5 && param.filter.height == 1 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "conv_add_1x5_prelu_channel_half") super.init(device: device, inFunctionName: "conv_add_1x5_prelu_channel_half", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "conv_add_1x5_prelu_element_half") super.init(device: device, inFunctionName: "conv_add_1x5_prelu_element_half", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_1x5_prelu_other_half") super.init(device: device, inFunctionName: "conv_add_1x5_prelu_other_half", initContext: initContext)
} }
} else { } else {
fatalError(" unsupport yet ") fatalError(" unsupport yet ")
} }
} else if computePrecision == .Float32 { } else if GlobalConfig.shared.computePrecision == .Float32 {
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "conv_add_1x1_prelu_channel_float") super.init(device: device, inFunctionName: "conv_add_1x1_prelu_channel_float", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "conv_add_1x1_prelu_element_float") super.init(device: device, inFunctionName: "conv_add_1x1_prelu_element_float", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_1x1_prelu_other_float") super.init(device: device, inFunctionName: "conv_add_1x1_prelu_other_float", initContext: initContext)
} }
} else if param.filter.channel == 1 { } else if param.filter.channel == 1 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_channel_float") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_channel_float", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_element_float") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_element_float", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_other_float") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_prelu_other_float", initContext: initContext)
} }
} else if param.filter.width == 3 && param.filter.height == 3 { } else if param.filter.width == 3 && param.filter.height == 3 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "conv_add_3x3_prelu_channel_float") super.init(device: device, inFunctionName: "conv_add_3x3_prelu_channel_float", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "conv_add_3x3_prelu_element_float") super.init(device: device, inFunctionName: "conv_add_3x3_prelu_element_float", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_3x3_prelu_other_float") super.init(device: device, inFunctionName: "conv_add_3x3_prelu_other_float", initContext: initContext)
} }
} else if param.filter.width == 1 && param.filter.height == 5 { } else if param.filter.width == 1 && param.filter.height == 5 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "conv_add_5x1_prelu_channel_float") super.init(device: device, inFunctionName: "conv_add_5x1_prelu_channel_float", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "conv_add_5x1_prelu_element_float") super.init(device: device, inFunctionName: "conv_add_5x1_prelu_element_float", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_5x1_prelu_other_float") super.init(device: device, inFunctionName: "conv_add_5x1_prelu_other_float", initContext: initContext)
} }
} else if param.filter.width == 5 && param.filter.height == 1 { } else if param.filter.width == 5 && param.filter.height == 1 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "conv_add_1x5_prelu_channel_float") super.init(device: device, inFunctionName: "conv_add_1x5_prelu_channel_float", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "conv_add_1x5_prelu_element_float") super.init(device: device, inFunctionName: "conv_add_1x5_prelu_element_float", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_add_1x5_prelu_other_float") super.init(device: device, inFunctionName: "conv_add_1x5_prelu_other_float", initContext: initContext)
} }
} else { } else {
fatalError(" unsupport yet ") fatalError(" unsupport yet ")
......
...@@ -38,44 +38,44 @@ struct ConvBNReluTestParam: TestParam { ...@@ -38,44 +38,44 @@ struct ConvBNReluTestParam: TestParam {
} }
class ConvBNReluKernel<P: PrecisionType>: Kernel, Computable, Testable { class ConvBNReluKernel<P: PrecisionType>: Kernel, Computable, Testable {
required init(device: MTLDevice, testParam: ConvBNReluTestParam) { required init(device: MTLDevice, testParam: ConvBNReluTestParam, initContext: InitContext) {
if testParam.filterSize.width == 1 && testParam.filterSize.height == 1 { if testParam.filterSize.width == 1 && testParam.filterSize.height == 1 {
super.init(device: device, inFunctionName: "conv_batch_norm_relu_1x1") super.init(device: device, inFunctionName: "conv_batch_norm_relu_1x1", initContext: initContext)
} else if testParam.filterSize.channel == 1 { } else if testParam.filterSize.channel == 1 {
super.init(device: device, inFunctionName: "depthwise_conv_batch_norm_relu_3x3") super.init(device: device, inFunctionName: "depthwise_conv_batch_norm_relu_3x3", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "conv_batch_norm_relu_3x3") super.init(device: device, inFunctionName: "conv_batch_norm_relu_3x3", initContext: initContext)
} }
} }
var metalParam: MetalConvParam! var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvBNReluParam<P>) { required init(device: MTLDevice, param: ConvBNReluParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
param.filter.initBuffer(device: device, precision: computePrecision) param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.variance.initBuffer(device: device, precision: .Float32) param.variance.initBuffer(device: device, precision: .Float32)
param.mean.initBuffer(device: device, precision: .Float32) param.mean.initBuffer(device: device, precision: .Float32)
param.scale.initBuffer(device: device, precision: .Float32) param.scale.initBuffer(device: device, precision: .Float32)
param.bias.initBuffer(device: device, precision: .Float32) param.bias.initBuffer(device: device, precision: .Float32)
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_batch_norm_relu_1x1") super.init(device: device, inFunctionName: "conv_batch_norm_relu_1x1", initContext: initContext)
} else if param.filter.channel == 1 { } else if param.filter.channel == 1 {
super.init(device: device, inFunctionName: "depthwise_conv_batch_norm_relu_3x3") super.init(device: device, inFunctionName: "depthwise_conv_batch_norm_relu_3x3", initContext: initContext)
} else if param.filter.width == 3 && param.filter.height == 3 { } else if param.filter.width == 3 && param.filter.height == 3 {
super.init(device: device, inFunctionName: "conv_batch_norm_relu_3x3") super.init(device: device, inFunctionName: "conv_batch_norm_relu_3x3", initContext: initContext)
} else { } else {
fatalError(" unsupport ") fatalError(" unsupport ")
} }
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_batch_norm_relu_1x1_half") super.init(device: device, inFunctionName: "conv_batch_norm_relu_1x1_half", initContext: initContext)
} else if param.filter.channel == 1 { } else if param.filter.channel == 1 {
super.init(device: device, inFunctionName: "depthwise_conv_batch_norm_relu_3x3_half") super.init(device: device, inFunctionName: "depthwise_conv_batch_norm_relu_3x3_half", initContext: initContext)
} else if param.filter.width == 3 && param.filter.height == 3 { } else if param.filter.width == 3 && param.filter.height == 3 {
super.init(device: device, inFunctionName: "conv_batch_norm_relu_3x3_half") super.init(device: device, inFunctionName: "conv_batch_norm_relu_3x3_half", initContext: initContext)
} else { } else {
fatalError(" unsupport ") fatalError(" unsupport ")
} }
...@@ -122,10 +122,10 @@ class ConvBNReluKernel<P: PrecisionType>: Kernel, Computable, Testable { ...@@ -122,10 +122,10 @@ class ConvBNReluKernel<P: PrecisionType>: Kernel, Computable, Testable {
var newBiaseBuffer: MTLBuffer var newBiaseBuffer: MTLBuffer
var newScaleBuffer: MTLBuffer var newScaleBuffer: MTLBuffer
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
newBiaseBuffer = device.makeBuffer(bytes: newBiase, length: param.bias.buffer.length)! newBiaseBuffer = device.makeBuffer(bytes: newBiase, length: param.bias.buffer.length)!
newScaleBuffer = device.makeBuffer(bytes: newScale, length: param.scale.buffer.length)! newScaleBuffer = device.makeBuffer(bytes: newScale, length: param.scale.buffer.length)!
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
newBiaseBuffer = device.makeBuffer(length: param.bias.buffer.length / 2)! newBiaseBuffer = device.makeBuffer(length: param.bias.buffer.length / 2)!
newScaleBuffer = device.makeBuffer(length: param.bias.buffer.length / 2)! newScaleBuffer = device.makeBuffer(length: param.bias.buffer.length / 2)!
......
...@@ -26,14 +26,14 @@ public struct MetalConvParam { ...@@ -26,14 +26,14 @@ public struct MetalConvParam {
class ConvKernel<P: PrecisionType>: Kernel, Computable { class ConvKernel<P: PrecisionType>: Kernel, Computable {
var metalParam: MetalConvParam! var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvParam<P>) { required init(device: MTLDevice, param: ConvParam<P>, initContext: InitContext) {
param.filter.initBuffer(device: device, precision: ComputePrecision.Float32) param.filter.initBuffer(device: device, precision: ComputePrecision.Float32)
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_1x1") super.init(device: device, inFunctionName: "conv_1x1", initContext: initContext)
} else if param.filter.channel == 1 { } else if param.filter.channel == 1 {
super.init(device: device, inFunctionName: "depthwise_conv_3x3") super.init(device: device, inFunctionName: "depthwise_conv_3x3", initContext: initContext)
} else if param.filter.width == 3 && param.filter.height == 3 { } else if param.filter.width == 3 && param.filter.height == 3 {
super.init(device: device, inFunctionName: "conv_3x3") super.init(device: device, inFunctionName: "conv_3x3", initContext: initContext)
} else { } else {
fatalError(" unsupport ") fatalError(" unsupport ")
} }
......
...@@ -30,18 +30,18 @@ struct MetalConvTransposeParam { ...@@ -30,18 +30,18 @@ struct MetalConvTransposeParam {
class ConvTransposeKernel<P: PrecisionType>: Kernel, Computable{ class ConvTransposeKernel<P: PrecisionType>: Kernel, Computable{
var metalParam: MetalConvTransposeParam! var metalParam: MetalConvTransposeParam!
required init(device: MTLDevice, param: ConvTransposeParam<P>) { required init(device: MTLDevice, param: ConvTransposeParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
param.filter.initBuffer(device: device, precision: computePrecision, convertToNHWC: false, withTranspose: true) param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision, convertToNHWC: false, withTranspose: true)
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
if param.stride == [2, 2] && param.stride == [2, 2] { if param.stride == [2, 2] && param.stride == [2, 2] {
super.init(device: device, inFunctionName: "conv_transpose2x2_stride2") super.init(device: device, inFunctionName: "conv_transpose2x2_stride2", initContext: initContext)
} else { } else {
fatalError(" -- conv transpose unsupported yet -- ") fatalError(" -- conv transpose unsupported yet -- ")
} }
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
if param.stride == [2, 2] && param.stride == [2, 2] { if param.stride == [2, 2] && param.stride == [2, 2] {
super.init(device: device, inFunctionName: "conv_transpose2x2_stride2_half") super.init(device: device, inFunctionName: "conv_transpose2x2_stride2_half", initContext: initContext)
} else { } else {
fatalError(" -- conv transpose unsupported yet -- ") fatalError(" -- conv transpose unsupported yet -- ")
} }
......
...@@ -26,8 +26,8 @@ struct ElementwiseAddMetalParam { ...@@ -26,8 +26,8 @@ struct ElementwiseAddMetalParam {
class ElementwiseAddKernel<P: PrecisionType>: Kernel, Computable { class ElementwiseAddKernel<P: PrecisionType>: Kernel, Computable {
var metalParam: ElementwiseAddMetalParam var metalParam: ElementwiseAddMetalParam
required init(device: MTLDevice, param: ElementwiseAddParam<P>) { required init(device: MTLDevice, param: ElementwiseAddParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: param.inputX.transpose, computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: param.inputX.transpose, computePrecision: GlobalConfig.shared.computePrecision)
metalParam = ElementwiseAddMetalParam.init() metalParam = ElementwiseAddMetalParam.init()
...@@ -50,10 +50,10 @@ class ElementwiseAddKernel<P: PrecisionType>: Kernel, Computable { ...@@ -50,10 +50,10 @@ class ElementwiseAddKernel<P: PrecisionType>: Kernel, Computable {
// print("===> elementwise_add fast!!!") // print("===> elementwise_add fast!!!")
metalParam.fast = 1 metalParam.fast = 1
} }
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "elementwise_add") super.init(device: device, inFunctionName: "elementwise_add", initContext: initContext)
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "elementwise_add_half") super.init(device: device, inFunctionName: "elementwise_add_half", initContext: initContext)
} else { } else {
fatalError() fatalError()
} }
......
...@@ -17,9 +17,9 @@ import Foundation ...@@ -17,9 +17,9 @@ import Foundation
class ElementwiseAddPreluKernel<P: PrecisionType>: Kernel, Computable { class ElementwiseAddPreluKernel<P: PrecisionType>: Kernel, Computable {
var metalParam: ElementwiseAddMetalParam var metalParam: ElementwiseAddMetalParam
required init(device: MTLDevice, param: ElementwiseAddPreluParam<P>) { required init(device: MTLDevice, param: ElementwiseAddPreluParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: param.inputX.transpose, computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: param.inputX.transpose, computePrecision: GlobalConfig.shared.computePrecision)
param.alpha.initBuffer(device: device, precision: computePrecision) param.alpha.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
metalParam = ElementwiseAddMetalParam.init() metalParam = ElementwiseAddMetalParam.init()
...@@ -43,21 +43,21 @@ class ElementwiseAddPreluKernel<P: PrecisionType>: Kernel, Computable { ...@@ -43,21 +43,21 @@ class ElementwiseAddPreluKernel<P: PrecisionType>: Kernel, Computable {
metalParam.fast = 1 metalParam.fast = 1
} }
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "elementwise_add_channel_float") super.init(device: device, inFunctionName: "elementwise_add_channel_float", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "elementwise_add_element_float") super.init(device: device, inFunctionName: "elementwise_add_element_float", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "elementwise_add_prelu_float") super.init(device: device, inFunctionName: "elementwise_add_prelu_float", initContext: initContext)
} }
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "elementwise_add_channel_half") super.init(device: device, inFunctionName: "elementwise_add_channel_half", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "elementwise_add_channel_half") super.init(device: device, inFunctionName: "elementwise_add_channel_half", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "elementwise_add_channel_half") super.init(device: device, inFunctionName: "elementwise_add_channel_half", initContext: initContext)
} }
} else { } else {
fatalError() fatalError()
......
...@@ -26,8 +26,8 @@ class FlattenKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -26,8 +26,8 @@ class FlattenKernel<P: PrecisionType>: Kernel, Computable{
var metalParam: FlattenMetalParam var metalParam: FlattenMetalParam
required init(device: MTLDevice, param: FlattenParam<P>) { required init(device: MTLDevice, param: FlattenParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, computePrecision: computePrecision) param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
var id: [Int32] = [1, 1, 1, 1] var id: [Int32] = [1, 1, 1, 1]
for i in 0..<param.input.tensorDim.cout() { for i in 0..<param.input.tensorDim.cout() {
id[4-param.input.tensorDim.cout()+i] = Int32(param.input.tensorDim[i]) id[4-param.input.tensorDim.cout()+i] = Int32(param.input.tensorDim[i])
...@@ -47,10 +47,10 @@ class FlattenKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -47,10 +47,10 @@ class FlattenKernel<P: PrecisionType>: Kernel, Computable{
let irank = param.input.tensorDim.cout() let irank = param.input.tensorDim.cout()
let orank = param.output.tensorDim.cout() let orank = param.output.tensorDim.cout()
assert(orank == 2) assert(orank == 2)
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "reshape_\(irank)_2_float") super.init(device: device, inFunctionName: "reshape_\(irank)_2_float", initContext: initContext)
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "reshape_\(irank)_2_half") super.init(device: device, inFunctionName: "reshape_\(irank)_2_half", initContext: initContext)
} else { } else {
fatalError() fatalError()
} }
......
...@@ -17,16 +17,16 @@ import Foundation ...@@ -17,16 +17,16 @@ import Foundation
class MulticlassNMSKernel<P: PrecisionType>: Kernel, Computable{ class MulticlassNMSKernel<P: PrecisionType>: Kernel, Computable{
let pipline1: MTLComputePipelineState let pipline1: MTLComputePipelineState
required init(device: MTLDevice, param: MulticlassNMSParam<P>) { required init(device: MTLDevice, param: MulticlassNMSParam<P>, initContext: InitContext) {
param.middleOutput.initBuffer(device: device) param.middleOutput.initBuffer(device: device)
param.bboxOutput.initBuffer(device: device) param.bboxOutput.initBuffer(device: device)
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
pipline1 = device.pipeLine(funcName: "nms_fetch_bbox", inPaddleMobileLib: true) pipline1 = device.pipeLine(funcName: "nms_fetch_bbox", metalLoadMode: initContext.metalLoadMode, metalLibPath: initContext.metalLibPath)
super.init(device: device, inFunctionName: "nms_fetch_result") super.init(device: device, inFunctionName: "nms_fetch_result", initContext: initContext)
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
pipline1 = device.pipeLine(funcName: "nms_fetch_bbox_half", inPaddleMobileLib: true) pipline1 = device.pipeLine(funcName: "nms_fetch_bbox_half", metalLoadMode: initContext.metalLoadMode, metalLibPath: initContext.metalLibPath)
super.init(device: device, inFunctionName: "nms_fetch_result_half") super.init(device: device, inFunctionName: "nms_fetch_result_half", initContext: initContext)
} else { } else {
fatalError( " unsupport precision " ) fatalError( " unsupport precision " )
} }
......
...@@ -26,8 +26,8 @@ struct PoolMetalParam { ...@@ -26,8 +26,8 @@ struct PoolMetalParam {
class PoolKernel<P: PrecisionType>: Kernel, Computable{ class PoolKernel<P: PrecisionType>: Kernel, Computable{
var metalParam: PoolMetalParam var metalParam: PoolMetalParam
required init(device: MTLDevice, param: PoolParam<P>) { required init(device: MTLDevice, param: PoolParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
var poolType: Int32 var poolType: Int32
switch param.poolType { switch param.poolType {
...@@ -48,10 +48,10 @@ class PoolKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -48,10 +48,10 @@ class PoolKernel<P: PrecisionType>: Kernel, Computable{
poolType: poolType poolType: poolType
) )
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "pool_float") super.init(device: device, inFunctionName: "pool_float", initContext: initContext)
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "pool_half") super.init(device: device, inFunctionName: "pool_half", initContext: initContext)
} else { } else {
fatalError() fatalError()
} }
......
...@@ -15,24 +15,24 @@ ...@@ -15,24 +15,24 @@
import Foundation import Foundation
class PreluKernel<P: PrecisionType>: Kernel, Computable{ class PreluKernel<P: PrecisionType>: Kernel, Computable{
required init(device: MTLDevice, param: PreluParam<P>) { required init(device: MTLDevice, param: PreluParam<P>, initContext: InitContext) {
param.alpha.initBuffer(device: device, precision: computePrecision) param.alpha.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "prelu_channel") super.init(device: device, inFunctionName: "prelu_channel", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "prelu_element") super.init(device: device, inFunctionName: "prelu_element", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "prelu_other") super.init(device: device, inFunctionName: "prelu_other", initContext: initContext)
} }
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
if param.mode == "channel" { if param.mode == "channel" {
super.init(device: device, inFunctionName: "prelu_channel_half") super.init(device: device, inFunctionName: "prelu_channel_half", initContext: initContext)
} else if param.mode == "element" { } else if param.mode == "element" {
super.init(device: device, inFunctionName: "prelu_element_half") super.init(device: device, inFunctionName: "prelu_element_half", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "prelu_other_half") super.init(device: device, inFunctionName: "prelu_other_half", initContext: initContext)
} }
} else { } else {
fatalError() fatalError()
......
...@@ -32,29 +32,28 @@ struct PriorBoxMetalParam { ...@@ -32,29 +32,28 @@ struct PriorBoxMetalParam {
class PriorBoxKernel<P: PrecisionType>: Kernel, Computable{ class PriorBoxKernel<P: PrecisionType>: Kernel, Computable{
var metalParam: PriorBoxMetalParam! var metalParam: PriorBoxMetalParam!
required init(device: MTLDevice, param: PriorBoxParam<P>) { required init(device: MTLDevice, param: PriorBoxParam<P>, initContext: InitContext) {
let originDim = param.output.tensorDim; let originDim = param.output.tensorDim;
param.output.tensorDim = Dim.init(inDim: [1, originDim[0], originDim[1], originDim[2] * originDim[3]]) param.output.tensorDim = Dim.init(inDim: [1, originDim[0], originDim[1], originDim[2] * originDim[3]])
param.output.padToFourDim = Dim.init(inDim: [1, originDim[0], originDim[1], originDim[2] * originDim[3]]) param.output.padToFourDim = Dim.init(inDim: [1, originDim[0], originDim[1], originDim[2] * originDim[3]])
param.output.initTexture(device: device, inTranspose: [0, 1, 2, 3], computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: [0, 1, 2, 3], computePrecision: GlobalConfig.shared.computePrecision)
param.outputVariances.initTexture(device: device, inTranspose: [2, 0, 1, 3], computePrecision: computePrecision) param.outputVariances.initTexture(device: device, inTranspose: [2, 0, 1, 3], computePrecision: GlobalConfig.shared.computePrecision)
if GlobalConfig.shared.computePrecision == .Float32 {
if computePrecision == .Float32 {
if param.min_max_aspect_ratios_order { if param.min_max_aspect_ratios_order {
super.init(device: device, inFunctionName: "prior_box_MinMaxAspectRatiosOrder") super.init(device: device, inFunctionName: "prior_box_MinMaxAspectRatiosOrder", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "prior_box") super.init(device: device, inFunctionName: "prior_box", initContext: initContext)
} }
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
if param.min_max_aspect_ratios_order { if param.min_max_aspect_ratios_order {
super.init(device: device, inFunctionName: "prior_box_MinMaxAspectRatiosOrder_half") super.init(device: device, inFunctionName: "prior_box_MinMaxAspectRatiosOrder_half", initContext: initContext)
} else { } else {
super.init(device: device, inFunctionName: "prior_box_half") super.init(device: device, inFunctionName: "prior_box_half", initContext: initContext)
} }
} else { } else {
fatalError() fatalError()
...@@ -105,12 +104,12 @@ class PriorBoxKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -105,12 +104,12 @@ class PriorBoxKernel<P: PrecisionType>: Kernel, Computable{
} }
} }
if computePrecision == .Float16 { if GlobalConfig.shared.computePrecision == .Float16 {
let buffer = device.makeBuffer(length: outputAspectRatior.count * MemoryLayout<Float16>.size) let buffer = device.makeBuffer(length: outputAspectRatior.count * MemoryLayout<Float16>.size)
float32ToFloat16(input: &outputAspectRatior, output:(buffer?.contents())!, count: outputAspectRatior.count) float32ToFloat16(input: &outputAspectRatior, output:(buffer?.contents())!, count: outputAspectRatior.count)
param.newAspectRatios = buffer param.newAspectRatios = buffer
} else if computePrecision == .Float32 { } else if GlobalConfig.shared.computePrecision == .Float32 {
let buffer = device.makeBuffer(bytes: outputAspectRatior, length: outputAspectRatior.count * MemoryLayout<Float32>.size, options: []) let buffer = device.makeBuffer(bytes: outputAspectRatior, length: outputAspectRatior.count * MemoryLayout<Float32>.size, options: [])
param.newAspectRatios = buffer param.newAspectRatios = buffer
} else { } else {
......
...@@ -25,11 +25,11 @@ class ReluKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -25,11 +25,11 @@ class ReluKernel<P: PrecisionType>: Kernel, Computable{
encoder.endEncoding() encoder.endEncoding()
} }
required init(device: MTLDevice, param: ReluParam<P>) { required init(device: MTLDevice, param: ReluParam<P>, initContext: InitContext) {
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "relu") super.init(device: device, inFunctionName: "relu", initContext: initContext)
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "relu_half") super.init(device: device, inFunctionName: "relu_half", initContext: initContext)
} else { } else {
fatalError() fatalError()
} }
......
...@@ -31,8 +31,8 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -31,8 +31,8 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{
var metalParam: ReshapeMetalParam var metalParam: ReshapeMetalParam
required init(device: MTLDevice, param: ReshapeParam<P>) { required init(device: MTLDevice, param: ReshapeParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, computePrecision: computePrecision) param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
var id: [Int32] = [1, 1, 1, 1] var id: [Int32] = [1, 1, 1, 1]
for i in 0..<param.input.tensorDim.cout() { for i in 0..<param.input.tensorDim.cout() {
id[4-param.input.tensorDim.cout()+i] = Int32(param.input.tensorDim[i]) id[4-param.input.tensorDim.cout()+i] = Int32(param.input.tensorDim[i])
...@@ -51,23 +51,23 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -51,23 +51,23 @@ class ReshapeKernel<P: PrecisionType>: Kernel, Computable{
) )
let irank = param.input.tensorDim.cout() let irank = param.input.tensorDim.cout()
let orank = param.output.tensorDim.cout() let orank = param.output.tensorDim.cout()
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "reshape_\(irank)_\(orank)_float") super.init(device: device, inFunctionName: "reshape_\(irank)_\(orank)_float", initContext: initContext)
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "reshape_\(irank)_\(orank)_half") super.init(device: device, inFunctionName: "reshape_\(irank)_\(orank)_half", initContext: initContext)
} else { } else {
fatalError() fatalError()
} }
} }
required init(device: MTLDevice, testParam: ReshapeTestParam) { required init(device: MTLDevice, testParam: ReshapeTestParam, initContext: InitContext) {
metalParam = ReshapeMetalParam.init( metalParam = ReshapeMetalParam.init(
idim: (0, 0, 0, 0), idim: (0, 0, 0, 0),
itrans: (0, 0, 0, 0), itrans: (0, 0, 0, 0),
odim: (0, 0, 0, 0), odim: (0, 0, 0, 0),
otrans: (0, 0, 0, 0) otrans: (0, 0, 0, 0)
) )
super.init(device: device, inFunctionName: "reshape") super.init(device: device, inFunctionName: "reshape", initContext: initContext)
} }
func compute(commandBuffer: MTLCommandBuffer, param: ReshapeParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: ReshapeParam<P>) throws {
......
...@@ -20,6 +20,17 @@ struct ResizeBilinearMetalParam { ...@@ -20,6 +20,17 @@ struct ResizeBilinearMetalParam {
} }
class ResizeBilinearKernel<P: PrecisionType>: Kernel, Computable{ class ResizeBilinearKernel<P: PrecisionType>: Kernel, Computable{
required init(device: MTLDevice, param: ResizeBilinearParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "resize_bilinear", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "resize_bilinear_half", initContext: initContext)
} else {
fatalError()
}
}
func compute(commandBuffer: MTLCommandBuffer, param: ResizeBilinearParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: ResizeBilinearParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil") throw PaddleMobileError.predictError(message: " encode is nil")
...@@ -35,15 +46,6 @@ class ResizeBilinearKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -35,15 +46,6 @@ class ResizeBilinearKernel<P: PrecisionType>: Kernel, Computable{
encoder.endEncoding() encoder.endEncoding()
} }
required init(device: MTLDevice, param: ResizeBilinearParam<P>) {
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision)
if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "resize_bilinear")
} else if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "resize_bilinear_half")
} else {
fatalError()
}
}
} }
//
// Scale.swift
// paddle-mobile
//
// Created by liuRuiLong on 2019/1/4.
// Copyright © 2019 orange. All rights reserved.
//
import Foundation
class ScaleKernel: CusomKernel {
init(device: MTLDevice, shape: Shape, metalLoadMode: MetalLoadMode, metalLibPath: String?) {
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "scale", outputDim: shape, metalLoadModel: metalLoadMode, metalLibPath: metalLibPath)
} else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "scale_half", outputDim: shape, metalLoadModel: metalLoadMode, metalLibPath: metalLibPath)
} else {
fatalError(" unsupport ")
}
}
}
...@@ -28,12 +28,12 @@ class ShapeKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -28,12 +28,12 @@ class ShapeKernel<P: PrecisionType>: Kernel, Computable{
// encoder.endEncoding() // encoder.endEncoding()
} }
required init(device: MTLDevice, param: ShapeParam<P>) { required init(device: MTLDevice, param: ShapeParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, computePrecision: computePrecision) param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "shape") super.init(device: device, inFunctionName: "shape", initContext: initContext)
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "shape_half") super.init(device: device, inFunctionName: "shape_half", initContext: initContext)
} else { } else {
fatalError() fatalError()
} }
......
...@@ -22,16 +22,16 @@ struct SoftmaxMetalParam { ...@@ -22,16 +22,16 @@ struct SoftmaxMetalParam {
class SoftmaxKernel<P: PrecisionType>: Kernel, Computable{ class SoftmaxKernel<P: PrecisionType>: Kernel, Computable{
var metalParam: SoftmaxMetalParam var metalParam: SoftmaxMetalParam
required init(device: MTLDevice, param: SoftmaxParam<P>) { required init(device: MTLDevice, param: SoftmaxParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, computePrecision: computePrecision) param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
metalParam = SoftmaxMetalParam.init( metalParam = SoftmaxMetalParam.init(
N: Int32(param.input.tensorDim[0]), N: Int32(param.input.tensorDim[0]),
K: Int32(param.input.tensorDim[1]) K: Int32(param.input.tensorDim[1])
) )
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "softmax_float") super.init(device: device, inFunctionName: "softmax_float", initContext: initContext)
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "softmax_half") super.init(device: device, inFunctionName: "softmax_half", initContext: initContext)
} else { } else {
fatalError() fatalError()
} }
......
...@@ -37,13 +37,13 @@ class SplitKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -37,13 +37,13 @@ class SplitKernel<P: PrecisionType>: Kernel, Computable{
encoder.endEncoding() encoder.endEncoding()
} }
required init(device: MTLDevice, param: SplitParam<P>) { required init(device: MTLDevice, param: SplitParam<P>, initContext: InitContext) {
// param.output.initTexture(device: device, computePrecision: computePrecision) // param.output.initTexture(device: device, computePrecision: computePrecision)
let num = param.outputList.count let num = param.outputList.count
let rank = param.input.tensorDim.cout() let rank = param.input.tensorDim.cout()
assert(num >= 2 && num <= 4) assert(num >= 2 && num <= 4)
for output in param.outputList { for output in param.outputList {
output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision) output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} }
smp = SplitMetalParam.init() smp = SplitMetalParam.init()
smp.idim = (Int32(param.input.dim[0]), Int32(param.input.dim[1]), Int32(param.input.dim[2]), Int32(param.input.dim[3])) smp.idim = (Int32(param.input.dim[0]), Int32(param.input.dim[1]), Int32(param.input.dim[2]), Int32(param.input.dim[3]))
...@@ -81,10 +81,10 @@ class SplitKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -81,10 +81,10 @@ class SplitKernel<P: PrecisionType>: Kernel, Computable{
if v == "normal" { if v == "normal" {
fatalError("split unsupported") fatalError("split unsupported")
} }
if computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "split_\(rank)_\(num)_\(v)_float") super.init(device: device, inFunctionName: "split_\(rank)_\(num)_\(v)_float", initContext: initContext)
} else if computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "split_\(rank)_\(num)_\(v)_half") super.init(device: device, inFunctionName: "split_\(rank)_\(num)_\(v)_half", initContext: initContext)
} else { } else {
fatalError() fatalError()
} }
......
...@@ -33,12 +33,12 @@ class Texture2DTo2DArrayKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -33,12 +33,12 @@ class Texture2DTo2DArrayKernel<P: PrecisionType>: Kernel, Computable{
encoder.endEncoding() encoder.endEncoding()
} }
required init(device: MTLDevice, param: FeedParam<P>) { required init(device: MTLDevice, param: FeedParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
if computePrecision == .Float16 { if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "texture2d_to_2d_array_half") super.init(device: device, inFunctionName: "texture2d_to_2d_array_half", initContext: initContext)
} else if computePrecision == .Float32 { } else if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "texture2d_to_2d_array") super.init(device: device, inFunctionName: "texture2d_to_2d_array", initContext: initContext)
} else { } else {
fatalError() fatalError()
} }
......
...@@ -22,8 +22,8 @@ struct TransposeMetalParam { ...@@ -22,8 +22,8 @@ struct TransposeMetalParam {
class TransposeKernel<P: PrecisionType>: Kernel, Computable { class TransposeKernel<P: PrecisionType>: Kernel, Computable {
var metalParam: TransposeMetalParam = TransposeMetalParam.init() var metalParam: TransposeMetalParam = TransposeMetalParam.init()
required init(device: MTLDevice, param: TransposeParam<P>) { required init(device: MTLDevice, param: TransposeParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, computePrecision: computePrecision) param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
let rank = param.input.tensorDim.cout() let rank = param.input.tensorDim.cout()
var axis: [Int] = [0, 1, 2, 3] var axis: [Int] = [0, 1, 2, 3]
for i in 0..<param.axis.count { for i in 0..<param.axis.count {
...@@ -43,13 +43,13 @@ class TransposeKernel<P: PrecisionType>: Kernel, Computable { ...@@ -43,13 +43,13 @@ class TransposeKernel<P: PrecisionType>: Kernel, Computable {
metalParam.oC = Int32(param.output.dim[3]) metalParam.oC = Int32(param.output.dim[3])
metalParam.axis = (Int32(naxis[0]), Int32(naxis[1]), Int32(naxis[2]), Int32(naxis[3])) metalParam.axis = (Int32(naxis[0]), Int32(naxis[1]), Int32(naxis[2]), Int32(naxis[3]))
var kernelFunc = "transpose_undefined" var kernelFunc = "transpose_undefined"
if computePrecision == .Float16 { if GlobalConfig.shared.computePrecision == .Float16 {
if param.input.transpose == axis { if param.input.transpose == axis {
kernelFunc = "transpose_copy_half" kernelFunc = "transpose_copy_half"
} else { } else {
kernelFunc = "transpose_\(rank)_half" kernelFunc = "transpose_\(rank)_half"
} }
} else if computePrecision == .Float32 { } else if GlobalConfig.shared.computePrecision == .Float32 {
if param.input.transpose == axis { if param.input.transpose == axis {
kernelFunc = "transpose_copy_float" kernelFunc = "transpose_copy_float"
} else { } else {
...@@ -60,7 +60,7 @@ class TransposeKernel<P: PrecisionType>: Kernel, Computable { ...@@ -60,7 +60,7 @@ class TransposeKernel<P: PrecisionType>: Kernel, Computable {
} }
print("===========>", kernelFunc) print("===========>", kernelFunc)
print(metalParam) print(metalParam)
super.init(device: device, inFunctionName: kernelFunc) super.init(device: device, inFunctionName: kernelFunc, initContext: initContext)
} }
func compute(commandBuffer: MTLCommandBuffer, param: TransposeParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: TransposeParam<P>) throws {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
kernel void batchnorm(texture2d_array<float, access::read> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
const device float4 * nscale [[buffer(0)]],
const device float4 * nbias [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
const float4 input = inTexture.read(gid.xy, gid.z);
float4 output = input * nscale[gid.z] + nbias[gid.z];
outTexture.write(output, gid.xy, gid.z);
}
kernel void batchnorm_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 * newScale [[buffer(0)]],
const device half4 * newBias [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
const half4 input = inTexture.read(gid.xy, gid.z);
half4 output = input * newScale[gid.z] + newBias[gid.z];
outTexture.write(output, gid.xy, gid.z);
}
//
// BatchNormRelu.metal
// paddle-mobile
//
#include <metal_stdlib>
using namespace metal;
struct MetalConvParam {
short offsetX;
short offsetY;
short offsetZ;
ushort strideX;
ushort strideY;
};
kernel void batch_norm_relu_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
const device float4 *new_scale [[buffer(0)]],
const device float4 *new_biase [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
float4 input;
float4 output;
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
input = inTexture.sample(sample, gid.x, gid.y, gid.z);
output = fmax(input * new_scale[gid.z] + new_biase[gid.z], 0.0);
outTexture.write(output, gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef P
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define FUNC(f, p) CONCAT2_(f, p)
#define VECTOR(p, n) CONCAT2(p, n)
kernel void FUNC(bilinear_interp, P)(texture2d_array<P, access::read> input [[texture(0)]],
texture2d_array<P, access::write> output [[texture(1)]],
constant bilinear_interp_param & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
VECTOR(P, 4) r;
if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) {
r = input.read(gid.xy, gid.z);
} else {
P w = gid.x * pm.ratio_w;
P h = gid.y * pm.ratio_h;
uint w0 = w, h0 = h;
uint w1 = w0 + 1, h1 = h0 + 1;
P w1lambda = w - w0, h1lambda = h - h0;
P w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda;
if (w1 >= input.get_width()) w1 = w0;
if (h1 >= input.get_height()) h1 = h0;
VECTOR(P, 4) r0 = input.read(uint2(w0, h0), gid.z);
VECTOR(P, 4) r1 = input.read(uint2(w1, h0), gid.z);
VECTOR(P, 4) r2 = input.read(uint2(w0, h1), gid.z);
VECTOR(P, 4) r3 = input.read(uint2(w1, h1), gid.z);
r = h2lambda * (w2lambda * r0 + w1lambda * r1)
+ h1lambda * (w2lambda * r2 + w1lambda * r3);
}
output.write(r, gid.xy, gid.z);
}
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
struct bilinear_interp_param {
float ratio_h;
float ratio_w;
};
#define P float
#include "BilinearInterp.inc.metal"
#undef P
#define P half
#include "BilinearInterp.inc.metal"
#undef P
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef P
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define FUNC(f, p) CONCAT2_(f, p)
#define VECTOR(p, n) CONCAT2(p, n)
kernel void FUNC(boxcoder, P)(texture2d_array<P, access::read> priorBox [[texture(0)]],
texture2d_array<P, access::read> priorBoxVar [[texture(1)]],
texture2d_array<P, access::read> targetBox [[texture(2)]],
texture2d_array<P, access::write> output[[texture(3)]],
uint3 gid [[thread_position_in_grid]]) {
VECTOR(P, 4) p = priorBox.read(uint2(0, gid.x), gid.z);
VECTOR(P, 4) pv = priorBoxVar.read(uint2(0, gid.x), gid.z);
VECTOR(P, 4) t;
t[0] = targetBox.read(uint2(0, gid.x), gid.z)[0];
t[1] = targetBox.read(uint2(1, gid.x), gid.z)[0];
t[2] = targetBox.read(uint2(2, gid.x), gid.z)[0];
t[3] = targetBox.read(uint2(3, gid.x), gid.z)[0];
P px = (p.x + p.z) / 2;
P py = (p.y + p.w) / 2;
P pw = p.z - p.x;
P ph = p.w - p.y;
P tx = pv.x * t.x * pw + px;
P ty = pv.y * t.y * ph + py;
P tw = exp(pv.z * t.z) * pw;
P th = exp(pv.w * t.w) * ph;
VECTOR(P, 4) r;
r.x = tx - tw / 2;
r.y = ty - th / 2;
r.z = tx + tw / 2;
r.w = ty + th / 2;
output.write(r, gid.xy, gid.z);
}
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
#define P float
#include "BoxCoder.inc.metal"
#undef P
#define P half
#include "BoxCoder.inc.metal"
#undef P
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
inline void xyzn2abcd_1(int xyzn[4], int abcd[4]) {
abcd[0] = abcd[1] = abcd[2] = 0;
abcd[3] = xyzn[0] * 4 + xyzn[3];
}
inline void xyzn2abcd_2(int xyzn[4], int abcd[4]) {
abcd[0] = abcd[1] = 0;
abcd[2] = xyzn[1];
abcd[3] = xyzn[0] * 4 + xyzn[3];
}
inline void xyzn2abcd_3(int xyzn[4], int abcd[4]) {
abcd[0] = 0;
abcd[3] = xyzn[0];
abcd[2] = xyzn[1];
abcd[1] = xyzn[2] * 4 + xyzn[3];
}
inline void xyzn2abcd_4(int C, int xyzn[4], int abcd[4]) {
abcd[2] = xyzn[0];
abcd[1] = xyzn[1];
uint t = xyzn[2] * 4 + xyzn[3];
abcd[0] = t / C;
abcd[3] = t % C;
}
inline void abcd2xyzn_1(int abcd[4], int xyzn[4]) {
xyzn[1] = xyzn[2] = 0;
xyzn[0] = abcd[3] / 4;
xyzn[1] = abcd[3] % 4;
}
inline void abcd2xyzn_2(int abcd[4], int xyzn[4]) {
xyzn[2] = 0;
xyzn[1] = abcd[2];
xyzn[0] = abcd[3] / 4;
xyzn[3] = abcd[3] % 4;
}
inline void abcd2xyzn_3(int abcd[4], int xyzn[4]) {
xyzn[0] = abcd[3];
xyzn[1] = abcd[2];
xyzn[2] = abcd[1] / 4;
xyzn[3] = abcd[1] % 4;
}
inline void abcd2xyzn_4(int C, int abcd[4], int xyzn[4]) {
xyzn[0] = abcd[2];
xyzn[1] = abcd[1];
uint t = abcd[0] * C + abcd[3];
xyzn[2] = t / 4;
xyzn[3] = t % 4;
}
inline void xyzn2abcd(int C, int xyzn[4], int abcd[4]) {
abcd[2] = xyzn[0];
abcd[1] = xyzn[1];
uint t = xyzn[2] * 4 + xyzn[3];
abcd[0] = t / C;
abcd[3] = t % C;
}
inline void abcd2xyzn(int C, int abcd[4], int xyzn[4]) {
xyzn[0] = abcd[2];
xyzn[1] = abcd[1];
uint t = abcd[0] * C + abcd[3];
xyzn[2] = t / 4;
xyzn[3] = t % 4;
}
inline int32_t abcd2index(int32_t dim[4], int32_t abcd[4]) {
int32_t r = abcd[0];
r = r * dim[1] + abcd[1];
r = r * dim[2] + abcd[2];
r = r * dim[3] + abcd[3];
return r;
}
inline void index2abcd(int32_t dim[4], int32_t ind, int32_t abcd[4]) {
abcd[3] = ind % dim[3]; ind /= dim[3];
abcd[2] = ind % dim[2]; ind /= dim[2];
abcd[1] = ind % dim[1]; ind /= dim[1];
abcd[0] = ind;
}
inline void trans(int32_t trans[4], int32_t ipos[4], int32_t opos[4]) {
for (int i = 0; i < 4; i++) {
opos[i] = ipos[trans[i]];
}
}
inline void invtrans(int32_t trans[4], int32_t ipos[4], int32_t opos[4]) {
for (int i = 0; i < 4; i++) {
opos[trans[i]] = ipos[i];
}
}
struct MetalConvParam {
short offsetX;
short offsetY;
short offsetZ;
ushort strideX;
ushort strideY;
ushort dilationX;
ushort dilationY;
};
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef P
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d
#define CONCAT5_(a, b, c, d, e) a ## _ ## b ## _ ## c ## _ ## d ## _ ## e
#define FUNC(f, r, n, v, p) CONCAT5_(f, r, n, v, p)
#define VECTOR(p, n) CONCAT2(p, n)
#define FUNC_R(f, r) CONCAT2_(f, r)
#if V == VX
#define VV x
#elif V == VY
#define VV y
#elif V == VZ
#define VV z
#else
#define VV normal
#endif
#if V == VNORMAL
//kernel void FUNC(concat, R, N, normal, P)(array<texture2d_array<P, access::read>, N> in [[texture(0)]],
// texture2d_array<P, access::read> out_x [[texture(N)]],
// texture2d_array<P, access::write> out [[texture(N+1)]],
// constant ConcatParam & pm [[buffer(0)]],
// uint3 gid [[thread_position_in_grid]]) {
//}
kernel void FUNC(concat, R, N, VV, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
texture2d_array<P, access::read> in1 [[texture(1)]],
#if N >= 3
texture2d_array<P, access::read> in2 [[texture(2)]],
#endif
#if N >= 4
texture2d_array<P, access::read> in3 [[texture(3)]],
#endif
#if N >= 5
texture2d_array<P, access::read> in4 [[texture(4)]],
#endif
#if N >= 6
texture2d_array<P, access::read> in5 [[texture(5)]],
#endif
texture2d_array<P, access::read> inx [[texture(N)]],
texture2d_array<P, access::write> out [[texture(N+1)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
ConcatParam cp = pm;
int xyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, abcd[4], oxyzn[4];
VECTOR(P, 4) r = inx.read(gid.xy, gid.z);
for (int i = 0; i < 4; i++) {
xyzn[3] = i;
#if R == 4
xyzn2abcd_4(cp.odim[3], xyzn, abcd);
#else
FUNC_R(xyzn2abcd, R)(xyzn, abcd);
#endif
int k = abcd[cp.axis] - cp.offset;
if (k < 0) continue;
int j = 0;
for (; j < N; j++) {
if (k < cp.vdim[j]) {
break;
}
k -= cp.vdim[j];
}
if (j == N) {
continue;
}
int ta = cp.odim[cp.axis];
abcd[cp.axis] = k;
cp.odim[cp.axis] = cp.vdim[j];
#if R == 4
abcd2xyzn_4(cp.odim[3], abcd, oxyzn);
#else
FUNC_R(abcd2xyzn, R)(abcd, oxyzn);
#endif
cp.odim[cp.axis] = ta;
switch (j) {
case 0: r[i] = in0.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
case 1: r[i] = in1.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
#if N >= 3
case 2: r[i] = in2.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
#endif
#if N >= 4
case 3: r[i] = in3.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
#endif
#if N >= 5
case 4: r[i] = in4.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
#endif
#if N >= 6
case 5: r[i] = in5.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; break;
#endif
}
}
out.write(r, gid.xy, gid.z);
}
#endif // V == NORMAL
#if V == VX
kernel void FUNC(concat, R, N, VV, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
texture2d_array<P, access::read> in1 [[texture(1)]],
#if N >= 3
texture2d_array<P, access::read> in2 [[texture(2)]],
#endif // N >= 3
#if N >= 4
texture2d_array<P, access::read> in3 [[texture(3)]],
#endif // N >= 4
#if N >= 5
texture2d_array<P, access::read> in4 [[texture(4)]],
#endif // N >= 5
#if N >= 6
texture2d_array<P, access::read> in5 [[texture(5)]],
#endif // N >= 6
texture2d_array<P, access::write> out [[texture(N)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
int x = gid.x - pm.offset;
if (x < 0) return;
if (x < pm.vdim[0]) {
VECTOR(P, 4) r = in0.read(gid.xy, gid.z);
out.write(r, gid.xy, gid.z);
return;
}
x -= pm.vdim[0];
if (x < pm.vdim[1]) {
VECTOR(P, 4) r = in1.read(uint2(x, gid.y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#if N >= 3
x -= pm.vdim[1];
if (x < pm.vdim[2]) {
VECTOR(P, 4) r = in2.read(uint2(x, gid.y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 3
#if N >= 4
x -= pm.vdim[2];
if (x < pm.vdim[3]) {
VECTOR(P, 4) r = in3.read(uint2(x, gid.y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 4
#if N >= 5
x -= pm.vdim[3];
if (x < pm.vdim[4]) {
VECTOR(P, 4) r = in4.read(uint2(x, gid.y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 5
#if N >= 6
x -= pm.vdim[4];
if (x < pm.vdim[5]) {
VECTOR(P, 4) r = in5.read(uint2(x, gid.y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 6
}
#endif // V == VX
#if V == VY
kernel void FUNC(concat, R, N, VV, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
texture2d_array<P, access::read> in1 [[texture(1)]],
#if N >= 3
texture2d_array<P, access::read> in2 [[texture(2)]],
#endif // N >= 3
#if N >= 4
texture2d_array<P, access::read> in3 [[texture(3)]],
#endif // N >= 4
#if N >= 5
texture2d_array<P, access::read> in4 [[texture(4)]],
#endif // N >= 5
#if N >= 6
texture2d_array<P, access::read> in5 [[texture(5)]],
#endif // N >= 6
texture2d_array<P, access::write> out [[texture(N)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
int y = gid.y - pm.offset;
if (y < 0) return;
if (y < pm.vdim[0]) {
VECTOR(P, 4) r = in0.read(gid.xy, gid.z);
out.write(r, gid.xy, gid.z);
return;
}
y -= pm.vdim[0];
if (y < pm.vdim[1]) {
VECTOR(P, 4) r = in1.read(uint2(gid.x, y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#if N >= 3
y -= pm.vdim[1];
if (y < pm.vdim[2]) {
VECTOR(P, 4) r = in2.read(uint2(gid.x, y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 3
#if N >= 4
y -= pm.vdim[2];
if (y < pm.vdim[3]) {
VECTOR(P, 4) r = in3.read(uint2(gid.x, y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 4
#if N >= 5
y -= pm.vdim[3];
if (y < pm.vdim[4]) {
VECTOR(P, 4) r = in4.read(uint2(gid.x, y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 5
#if N >= 6
y -= pm.vdim[4];
if (y < pm.vdim[5]) {
VECTOR(P, 4) r = in5.read(uint2(gid.x, y), gid.z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 6
}
#endif // V == VY
#if V == VZ
kernel void FUNC(concat, R, N, VV, P)(texture2d_array<P, access::read> in0 [[texture(0)]],
texture2d_array<P, access::read> in1 [[texture(1)]],
#if N >= 3
texture2d_array<P, access::read> in2 [[texture(2)]],
#endif // N >= 3
#if N >= 4
texture2d_array<P, access::read> in3 [[texture(3)]],
#endif // N >= 4
#if N >= 5
texture2d_array<P, access::read> in4 [[texture(4)]],
#endif // N >= 5
#if N >= 6
texture2d_array<P, access::read> in5 [[texture(5)]],
#endif // N >= 6
texture2d_array<P, access::write> out [[texture(N)]],
constant ConcatParam & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
int z = gid.z - pm.offset;
if (z < 0) return;
if (z < pm.vdim[0]) {
VECTOR(P, 4) r = in0.read(gid.xy, gid.z);
out.write(r, gid.xy, gid.z);
return;
}
z -= pm.vdim[0];
if (z < pm.vdim[1]) {
VECTOR(P, 4) r = in1.read(gid.xy, z);
out.write(r, gid.xy, gid.z);
return;
}
#if N >= 3
z -= pm.vdim[1];
if (z < pm.vdim[2]) {
VECTOR(P, 4) r = in2.read(gid.xy, z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 3
#if N >= 4
z -= pm.vdim[2];
if (z < pm.vdim[3]) {
VECTOR(P, 4) r = in3.read(gid.xy, z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 4
#if N >= 5
z -= pm.vdim[3];
if (z < pm.vdim[4]) {
VECTOR(P, 4) r = in4.read(gid.xy, z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 5
#if N >= 6
z -= pm.vdim[4];
if (z < pm.vdim[5]) {
VECTOR(P, 4) r = in5.read(gid.xy, z);
out.write(r, gid.xy, gid.z);
return;
}
#endif // N >= 6
}
#endif // V == VZ
#undef VV
#endif // #ifdef P
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
struct ConcatParam {
int32_t odim[4];
int32_t axis;
int32_t offset;
int32_t trans[4];
int32_t vdim[6];
};
#define VNORMAL 1
#define VX 2
#define VY 3
#define VZ 4
// >> fast mode
// only support concat_{2,3,4}_{2,3,4,5,6}_y_{float,half}
// only support concat_{3,4}_{2,3,4,5,6}_x_{float,half}
// only support concat_{1,2,3,4}_{2,3,4,5,6}_z_{float,half}
// >> normal mode (loop mode)
// ssd-ar: (R=4, N=3, V=z), (R=3, N=2, V=y), (R=2, N=5, V=x), (R=3, N=5, V=x)
// ssd: (R=2, N=6, V=y), (R=3, N=6, V=y)
// genet: (R=4, N=2, V=normal)
// ssd-ar: (R=3, N=5, V=x)
#define V VX
#define R 3
#define N 5
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
// ssd-ar: (R=2, N=5, V=x)
#define V VX
#define R 2
#define N 5
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
// ssd-ar: (R=3, N=2, V=y)
#define V VY
#define R 3
#define N 2
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
// ssd-ar: (R=4, N=3, V=z)
#define V VZ
#define R 4
#define N 3
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
// ssd: (R=2, N=6, V=y)
#define V VY
#define R 2
#define N 6
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
// ssd: (R=3, N=6, V=y)
#define V VY
#define R 3
#define N 6
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
#define V VNORMAL
#define R 4
#define N 2
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
#define V VY
#define R 2
#define N 2
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
#define V VY
#define R 2
#define N 5
#define P float
#include "ConcatKernel.inc.metal"
#undef P
#define P half
#include "ConcatKernel.inc.metal"
#undef P
#undef N
#undef R
#undef V
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
kernel void conv_add_batch_norm_relu_1x1_half(
texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half4 *weights [[buffer(1)]],
const device half4 *biase [[buffer(2)]],
const device half4 *new_scale [[buffer(3)]],
const device half4 *new_biase [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 1;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
half4 input;
for (uint i = 0; i < input_arr_size; ++i) {
input = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
half4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + i];
output.x += dot(input, weight_x);
half4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + i];
output.y += dot(input, weight_y);
half4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + i];
output.z += dot(input, weight_z);
half4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + i];
output.w += dot(input, weight_w);
}
output = fmax((output + float4(biase[gid.z])) * float4(new_scale[gid.z]) + float4(new_biase[gid.z]), 0.0);
outTexture.write(half4(output), gid.xy, gid.z);
}
kernel void conv_add_batch_norm_relu_3x3_half(
texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half4 *weights [[buffer(1)]],
const device half4 *biase [[buffer(2)]],
const device half4 *new_scale [[buffer(3)]],
const device half4 *new_biase [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
half4 input[9];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i);
for (int j = 0; j < 9; ++j) {
half4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(input[j], weight_x);
half4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(input[j], weight_y);
half4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(input[j], weight_z);
half4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(input[j], weight_w);
}
}
output = fmax((output + float4(biase[gid.z])) * float4(new_scale[gid.z]) + float4(new_biase[gid.z]), 0.0);
outTexture.write(half4(output), gid.xy, gid.z);
}
kernel void depthwise_conv_add_batch_norm_relu_3x3_half(
texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half *weights [[buffer(1)]],
const device half4 *biase [[buffer(2)]],
const device half4 *new_scale [[buffer(3)]],
const device half4 *new_biase [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
uint output_slice = gid.z;
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint weithTo = gid.z * kernelHXW * 4;
float4 output = float4(0.0);
half4 inputs[9];
inputs[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), output_slice);
inputs[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), output_slice);
inputs[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), output_slice);
inputs[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), output_slice);
inputs[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), output_slice);
inputs[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), output_slice);
inputs[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), output_slice);
inputs[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), output_slice);
inputs[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), output_slice);
for (int j = 0; j < 9; ++j) {
half4 input = inputs[j];
output.x += input.x * weights[weithTo + 0 * kernelHXW + j];
output.y += input.y * weights[weithTo + 1 * kernelHXW + j];
output.z += input.z * weights[weithTo + 2 * kernelHXW + j];
output.w += input.w * weights[weithTo + 3 * kernelHXW + j];
}
output = fmax((output + float4(biase[gid.z])) * float4(new_scale[gid.z]) + float4(new_biase[gid.z]), 0.0);
outTexture.write(half4(output), gid.xy, gid.z);
}
/*---------------------------------------------*/
kernel void conv_add_batch_norm_relu_1x1(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
const device float4 *biase [[buffer(2)]],
const device float4 *new_scale [[buffer(3)]],
const device float4 *new_biase [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 1;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
float4 input;
for (uint i = 0; i < input_arr_size; ++i) {
input = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + i];
output.x += dot(input, weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + i];
output.y += dot(input, weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + i];
output.z += dot(input, weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + i];
output.w += dot(input, weight_w);
}
output = fmax((output + biase[gid.z]) * new_scale[gid.z] + new_biase[gid.z], 0.0);
outTexture.write(output, gid.xy, gid.z);
}
kernel void conv_add_batch_norm_relu_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
const device float4 *biase [[buffer(2)]],
const device float4 *new_scale [[buffer(3)]],
const device float4 *new_biase [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
float4 input[9];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i);
for (int j = 0; j < 9; ++j) {
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(input[j], weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(input[j], weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(input[j], weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(input[j], weight_w);
}
}
output = fmax((output + biase[gid.z]) * new_scale[gid.z] + new_biase[gid.z], 0.0);
outTexture.write(output, gid.xy, gid.z);
}
kernel void depthwise_conv_add_batch_norm_relu_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float *weights [[buffer(1)]],
const device float4 *biase [[buffer(2)]],
const device float4 *new_scale [[buffer(3)]],
const device float4 *new_biase [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
uint output_slice = gid.z;
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint weithTo = gid.z * kernelHXW * 4;
float4 output = float4(0.0);
float4 inputs[9];
inputs[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), output_slice);
inputs[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), output_slice);
inputs[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), output_slice);
inputs[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), output_slice);
inputs[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), output_slice);
inputs[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), output_slice);
inputs[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), output_slice);
inputs[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), output_slice);
inputs[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), output_slice);
for (int j = 0; j < 9; ++j) {
float4 input = inputs[j];
output.x += input.x * weights[weithTo + 0 * kernelHXW + j];
output.y += input.y * weights[weithTo + 1 * kernelHXW + j];
output.z += input.z * weights[weithTo + 2 * kernelHXW + j];
output.w += input.w * weights[weithTo + 3 * kernelHXW + j];
}
output = fmax((output + biase[gid.z]) * new_scale[gid.z] + new_biase[gid.z], 0.0);
outTexture.write(output, gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
#pragma mark - convAdd
kernel void conv_add_1x1(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
const device float4 *biase [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 1;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = biase[gid.z];
float4 input;
for (uint i = 0; i < input_arr_size; ++i) {
input = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + i];
output.x += dot(input, weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + i];
output.y += dot(input, weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + i];
output.z += dot(input, weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + i];
output.w += dot(input, weight_w);
}
// output = output + biase[gid.z];
outTexture.write(output, gid.xy, gid.z);
}
kernel void conv_add_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
const device float4 *biase [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = biase[gid.z];
ushort dilation_x = param.dilationX;
ushort dilation_y = param.dilationY;
float4 input[9];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y - dilation_y), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - dilation_y), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y - dilation_y), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y + dilation_y), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + dilation_y), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y + dilation_y), i);
for (int j = 0; j < 9; ++j) {
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(input[j], weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(input[j], weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(input[j], weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(input[j], weight_w);
}
}
// output = output + biase[gid.z];
outTexture.write(output, gid.xy, gid.z);
}
kernel void conv_add_5x1(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
const device float4 *biase [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 5;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = biase[gid.z];
ushort dilation_y = param.dilationY;
float4 input[5];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 2 * dilation_y), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - dilation_y), i);
input[2] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[3] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + dilation_y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 2 * dilation_y), i);
for (int j = 0; j < 5; ++j) {
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(input[j], weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(input[j], weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(input[j], weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(input[j], weight_w);
}
}
// output = output + biase[gid.z];
outTexture.write(output, gid.xy, gid.z);
}
kernel void conv_add_1x5(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
const device float4 *biase [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 5;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = biase[gid.z];
ushort dilation_x = param.dilationX;
float4 input[5];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x - 2 * dilation_x, posInInput.y), i);
input[1] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y), i);
input[2] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[3] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x + 2 * dilation_x, posInInput.y), i);
for (int j = 0; j < 5; ++j) {
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(input[j], weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(input[j], weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(input[j], weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(input[j], weight_w);
}
}
// output = output + biase[gid.z];
outTexture.write(output, gid.xy, gid.z);
}
kernel void depthwise_conv_add_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float *weights [[buffer(1)]],
const device float4 *biase [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
uint output_slice = gid.z;
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint weithTo = gid.z * kernelHXW * 4;
float4 output = biase[gid.z];
float4 inputs[9];
inputs[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), output_slice);
inputs[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), output_slice);
inputs[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), output_slice);
inputs[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), output_slice);
inputs[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), output_slice);
inputs[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), output_slice);
inputs[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), output_slice);
inputs[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), output_slice);
inputs[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), output_slice);
for (int j = 0; j < 9; ++j) {
float4 input = inputs[j];
output.x += input.x * weights[weithTo + 0 * kernelHXW + j];
output.y += input.y * weights[weithTo + 1 * kernelHXW + j];
output.z += input.z * weights[weithTo + 2 * kernelHXW + j];
output.w += input.w * weights[weithTo + 3 * kernelHXW + j];
}
// output = output + biase[gid.z];
outTexture.write(output, gid.xy, gid.z);
}
#pragma mark - half
kernel void conv_add_1x1_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half4 *weights [[buffer(1)]],
const device half4 *biase [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 1;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
half4 output = biase[gid.z];
half4 input;
for (uint i = 0; i < input_arr_size; ++i) {
input = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
half4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + i];
output.x += dot(input, weight_x);
half4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + i];
output.y += dot(input, weight_y);
half4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + i];
output.z += dot(input, weight_z);
half4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + i];
output.w += dot(input, weight_w);
}
// output = output + float4(biase[gid.z]);
outTexture.write(output, gid.xy, gid.z);
}
kernel void conv_add_3x3_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half4 *weights [[buffer(1)]],
const device half4 *biase [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
half4 output = biase[gid.z];
ushort dilation_x = param.dilationX;
ushort dilation_y = param.dilationY;
half4 input[9];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y - dilation_y), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - dilation_y), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y - dilation_y), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y + dilation_y), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + dilation_y), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y + dilation_y), i);
for (int j = 0; j < 9; ++j) {
half4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(float4(input[j]), float4(weight_x));
half4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(float4(input[j]), float4(weight_y));
half4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(float4(input[j]), float4(weight_z));
half4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(float4(input[j]), float4(weight_w));
}
}
// output = output + float4(biase[gid.z]);
outTexture.write(output, gid.xy, gid.z);
}
kernel void depthwise_conv_add_3x3_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half *weights [[buffer(1)]],
const device half4 *biase [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
uint output_slice = gid.z;
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint weithTo = gid.z * kernelHXW * 4;
half4 output = biase[gid.z];
half4 inputs[9];
inputs[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), output_slice);
inputs[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), output_slice);
inputs[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), output_slice);
inputs[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), output_slice);
inputs[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), output_slice);
inputs[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), output_slice);
inputs[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), output_slice);
inputs[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), output_slice);
inputs[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), output_slice);
for (int j = 0; j < 9; ++j) {
half4 input = inputs[j];
output.x += input.x * weights[weithTo + 0 * kernelHXW + j];
output.y += input.y * weights[weithTo + 1 * kernelHXW + j];
output.z += input.z * weights[weithTo + 2 * kernelHXW + j];
output.w += input.w * weights[weithTo + 3 * kernelHXW + j];
}
// output = output + float4(biase[gid.z]);
outTexture.write(output, gid.xy, gid.z);
}
kernel void conv_add_5x1_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half4 *weights [[buffer(1)]],
const device half4 *biase [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 5;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
half4 output = biase[gid.z];
ushort dilation_y = param.dilationY;
half4 input[5];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 2 * dilation_y), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - dilation_y), i);
input[2] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[3] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + dilation_y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 2 * dilation_y), i);
for (int j = 0; j < 5; ++j) {
half4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(input[j], weight_x);
half4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(input[j], weight_y);
half4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(input[j], weight_z);
half4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(input[j], weight_w);
}
}
// output = output + float4(biase[gid.z]);
outTexture.write(output, gid.xy, gid.z);
}
kernel void conv_add_1x5_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half4 *weights [[buffer(1)]],
const device half4 *biase [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 5;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
half4 output = biase[gid.z];
ushort dilation_x = param.dilationX;
half4 input[5];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x - 2 * dilation_x, posInInput.y), i);
input[1] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y), i);
input[2] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[3] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x + 2 * dilation_x, posInInput.y), i);
for (int j = 0; j < 5; ++j) {
half4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(input[j], weight_x);
half4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(input[j], weight_y);
half4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(input[j], weight_z);
half4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(input[j], weight_w);
}
}
// output = output + float4(biase[gid.z]);
outTexture.write(output, gid.xy, gid.z);
}
kernel void test_conv_add_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
const device float4 *biase [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
if (gid.x > 0 || gid.y > 0 || gid.z > 0) { return; }
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
ushort dilation_x = param.dilationX;
ushort dilation_y = param.dilationY;
float4 input[9];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y - dilation_y), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - dilation_y), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y - dilation_y), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y + dilation_y), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + dilation_y), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y + dilation_y), i);
for (int j = 0; j < 9; ++j) {
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(input[j], weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(input[j], weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(input[j], weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(input[j], weight_w);
}
}
// output = output + biase[gid.z];
outTexture.write(output, gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef P
#include "Macro.metal"
#pragma mark - convAdd
kernel void FUNC3_(conv_add_1x1, PRELU_TYPE, P)(texture2d_array<P, access::sample> inTexture [[texture(0)]],
texture2d_array<P, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device VECTOR(P, 4) *weights [[buffer(1)]],
const device VECTOR(P, 4) *biase [[buffer(2)]],
#ifdef PRELU_CHANNEL
const device VECTOR(P, 4) *alpha [[buffer(3)]],
#endif
#ifdef PRELU_ELEMENT
const device VECTOR(P, 4) *alpha [[buffer(3)]],
#endif
#ifdef PRELU_OTHER
const device P *alpha [[buffer(3)]],
#endif
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 1;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
VECTOR(P, 4) output = biase[gid.z];
VECTOR(P, 4) input;
for (uint i = 0; i < input_arr_size; ++i) {
input = inTexture.sample(sample,float2(posInInput.x, posInInput.y), i);
VECTOR(P, 4) weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + i];
output.x += dot(input, weight_x);
VECTOR(P, 4) weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + i];
output.y += dot(input, weight_y);
VECTOR(P, 4) weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + i];
output.z += dot(input, weight_z);
VECTOR(P, 4) weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + i];
output.w += dot(input, weight_w);
}
// output = output + float4(biase[gid.z]);
#ifdef PRELU_CHANNEL
VECTOR(P, 4) alpha_value = alpha[gid.z];
output.x = output.x > 0 ? output.x : (alpha_value.x * output.x);
output.y = output.y > 0 ? output.y : (alpha_value.y * output.y);
output.z = output.z > 0 ? output.z : (alpha_value.z * output.z);
output.w = output.w > 0 ? output.w : (alpha_value.w * output.w);
#endif
#ifdef PRELU_ELEMENT
int alpha_to = (gid.y * outTexture.get_width() + gid.x) * outTexture.get_array_size();
VECTOR(P, 4) alpha_value = alpha[alpha_to + gid.z];
output.x = output.x > 0 ? output.x : (alpha_value.x * output.x);
output.y = output.y > 0 ? output.y : (alpha_value.y * output.y);
output.z = output.z > 0 ? output.z : (alpha_value.z * output.z);
output.w = output.w > 0 ? output.w : (alpha_value.w * output.w);
#endif
#ifdef PRELU_OTHER
P alpha_value = alpha[0];
output.x = output.x > 0 ? output.x : (alpha_value * output.x);
output.y = output.y > 0 ? output.y : (alpha_value * output.y);
output.z = output.z > 0 ? output.z : (alpha_value * output.z);
output.w = output.w > 0 ? output.w : (alpha_value * output.w);
#endif
outTexture.write(VECTOR(P, 4)(output), gid.xy, gid.z);
}
kernel void FUNC3_(conv_add_3x3, PRELU_TYPE, P)(texture2d_array<P, access::sample> inTexture [[texture(0)]],
texture2d_array<P, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device VECTOR(P, 4) *weights [[buffer(1)]],
const device VECTOR(P, 4) *biase [[buffer(2)]],
#ifdef PRELU_CHANNEL
const device VECTOR(P, 4) *alpha [[buffer(3)]],
#endif
#ifdef PRELU_ELEMENT
const device VECTOR(P, 4) *alpha [[buffer(3)]],
#endif
#ifdef PRELU_OTHER
const device P *alpha [[buffer(3)]],
#endif
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
VECTOR(P, 4) output = biase[gid.z];
ushort dilation_x = param.dilationX;
ushort dilation_y = param.dilationY;
VECTOR(P, 4) input[9];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y - dilation_y), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - dilation_y), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y - dilation_y), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y + dilation_y), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + dilation_y), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y + dilation_y), i);
for (int j = 0; j < 9; ++j) {
VECTOR(P, 4) weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(input[j], weight_x);
VECTOR(P, 4) weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(input[j], weight_y);
VECTOR(P, 4) weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(input[j], weight_z);
VECTOR(P, 4) weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(input[j], weight_w);
}
}
// output = output + float4(biase[gid.z]);
#ifdef PRELU_CHANNEL
VECTOR(P, 4) alpha_value = alpha[gid.z];
output.x = output.x > 0 ? output.x : (alpha_value.x * output.x);
output.y = output.y > 0 ? output.y : (alpha_value.y * output.y);
output.z = output.z > 0 ? output.z : (alpha_value.z * output.z);
output.w = output.w > 0 ? output.w : (alpha_value.w * output.w);
#endif
#ifdef PRELU_ELEMENT
int alpha_to = (gid.y * outTexture.get_width() + gid.x) * outTexture.get_array_size();
VECTOR(P, 4) alpha_value = alpha[alpha_to + gid.z];
output.x = output.x > 0 ? output.x : (alpha_value.x * output.x);
output.y = output.y > 0 ? output.y : (alpha_value.y * output.y);
output.z = output.z > 0 ? output.z : (alpha_value.z * output.z);
output.w = output.w > 0 ? output.w : (alpha_value.w * output.w);
#endif
#ifdef PRELU_OTHER
P alpha_value = alpha[0];
output.x = output.x > 0 ? output.x : (alpha_value * output.x);
output.y = output.y > 0 ? output.y : (alpha_value * output.y);
output.z = output.z > 0 ? output.z : (alpha_value * output.z);
output.w = output.w > 0 ? output.w : (alpha_value * output.w);
#endif
outTexture.write(VECTOR(P, 4)(output), gid.xy, gid.z);
}
kernel void FUNC3_(conv_add_5x1, PRELU_TYPE, P)(texture2d_array<P, access::sample> inTexture [[texture(0)]],
texture2d_array<P, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device VECTOR(P, 4) *weights [[buffer(1)]],
const device VECTOR(P, 4) *biase [[buffer(2)]],
#ifdef PRELU_CHANNEL
const device VECTOR(P, 4) *alpha [[buffer(3)]],
#endif
#ifdef PRELU_ELEMENT
const device VECTOR(P, 4) *alpha [[buffer(3)]],
#endif
#ifdef PRELU_OTHER
const device P *alpha [[buffer(3)]],
#endif
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 5;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
VECTOR(P, 4) output = biase[gid.z];;
ushort dilation_y = param.dilationY;
VECTOR(P, 4) input[5];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 2 * dilation_y), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - dilation_y), i);
input[2] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[3] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + dilation_y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 2 * dilation_y), i);
for (int j = 0; j < 5; ++j) {
VECTOR(P, 4) weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(input[j], weight_x);
VECTOR(P, 4) weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(input[j], weight_y);
VECTOR(P, 4) weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(input[j], weight_z);
VECTOR(P, 4) weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(input[j], weight_w);
}
}
#ifdef PRELU_CHANNEL
VECTOR(P, 4) alpha_value = alpha[gid.z];
output.x = output.x > 0 ? output.x : (alpha_value.x * output.x);
output.y = output.y > 0 ? output.y : (alpha_value.y * output.y);
output.z = output.z > 0 ? output.z : (alpha_value.z * output.z);
output.w = output.w > 0 ? output.w : (alpha_value.w * output.w);
#endif
#ifdef PRELU_ELEMENT
int alpha_to = (gid.y * outTexture.get_width() + gid.x) * outTexture.get_array_size();
VECTOR(P, 4) alpha_value = alpha[alpha_to + gid.z];
output.x = output.x > 0 ? output.x : (alpha_value.x * output.x);
output.y = output.y > 0 ? output.y : (alpha_value.y * output.y);
output.z = output.z > 0 ? output.z : (alpha_value.z * output.z);
output.w = output.w > 0 ? output.w : (alpha_value.w * output.w);
#endif
#ifdef PRELU_OTHER
P alpha_value = alpha[0];
output.x = output.x > 0 ? output.x : (alpha_value * output.x);
output.y = output.y > 0 ? output.y : (alpha_value * output.y);
output.z = output.z > 0 ? output.z : (alpha_value * output.z);
output.w = output.w > 0 ? output.w : (alpha_value * output.w);
#endif
outTexture.write(VECTOR(P, 4)(output), gid.xy, gid.z);
}
kernel void FUNC3_(conv_add_1x5, PRELU_TYPE, P)(texture2d_array<P, access::sample> inTexture [[texture(0)]],
texture2d_array<P, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device VECTOR(P, 4) *weights [[buffer(1)]],
const device VECTOR(P, 4) *biase [[buffer(2)]],
#ifdef PRELU_CHANNEL
const device VECTOR(P, 4) *alpha [[buffer(3)]],
#endif
#ifdef PRELU_ELEMENT
const device VECTOR(P, 4) *alpha [[buffer(3)]],
#endif
#ifdef PRELU_OTHER
const device P *alpha [[buffer(3)]],
#endif
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 5;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
VECTOR(P, 4) output = biase[gid.z];
ushort dilation_x = param.dilationX;
VECTOR(P, 4) input[5];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x - 2 * dilation_x, posInInput.y), i);
input[1] = inTexture.sample(sample, float2(posInInput.x - dilation_x, posInInput.y), i);
input[2] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[3] = inTexture.sample(sample, float2(posInInput.x + dilation_x, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x + 2 * dilation_x, posInInput.y), i);
for (int j = 0; j < 5; ++j) {
VECTOR(P, 4) weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(input[j], weight_x);
VECTOR(P, 4) weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(input[j], weight_y);
VECTOR(P, 4) weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(input[j], weight_z);
VECTOR(P, 4) weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(input[j], weight_w);
}
}
#ifdef PRELU_CHANNEL
VECTOR(P, 4) alpha_value = alpha[gid.z];
output.x = output.x > 0 ? output.x : (alpha_value.x * output.x);
output.y = output.y > 0 ? output.y : (alpha_value.y * output.y);
output.z = output.z > 0 ? output.z : (alpha_value.z * output.z);
output.w = output.w > 0 ? output.w : (alpha_value.w * output.w);
#endif
#ifdef PRELU_ELEMENT
int alpha_to = (gid.y * outTexture.get_width() + gid.x) * outTexture.get_array_size();
VECTOR(P, 4) alpha_value = alpha[alpha_to + gid.z];
output.x = output.x > 0 ? output.x : (alpha_value.x * output.x);
output.y = output.y > 0 ? output.y : (alpha_value.y * output.y);
output.z = output.z > 0 ? output.z : (alpha_value.z * output.z);
output.w = output.w > 0 ? output.w : (alpha_value.w * output.w);
#endif
#ifdef PRELU_OTHER
P alpha_value = alpha[0];
output.x = output.x > 0 ? output.x : (alpha_value * output.x);
output.y = output.y > 0 ? output.y : (alpha_value * output.y);
output.z = output.z > 0 ? output.z : (alpha_value * output.z);
output.w = output.w > 0 ? output.w : (alpha_value * output.w);
#endif
outTexture.write(VECTOR(P, 4)(output), gid.xy, gid.z);
}
kernel void FUNC3_(depthwise_conv_add_3x3, PRELU_TYPE, P)(texture2d_array<P, access::sample> inTexture [[texture(0)]],
texture2d_array<P, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device P *weights [[buffer(1)]],
const device VECTOR(P, 4) *biase [[buffer(2)]],
#ifdef PRELU_CHANNEL
const device VECTOR(P, 4) *alpha [[buffer(3)]],
#endif
#ifdef PRELU_ELEMENT
const device VECTOR(P, 4) *alpha [[buffer(3)]],
#endif
#ifdef PRELU_OTHER
const device P *alpha [[buffer(3)]],
#endif
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
uint output_slice = gid.z;
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint weithTo = gid.z * kernelHXW * 4;
VECTOR(P, 4) output = biase[gid.z];
VECTOR(P, 4) inputs[9];
inputs[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), output_slice);
inputs[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), output_slice);
inputs[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), output_slice);
inputs[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), output_slice);
inputs[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), output_slice);
inputs[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), output_slice);
inputs[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), output_slice);
inputs[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), output_slice);
inputs[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), output_slice);
for (int j = 0; j < 9; ++j) {
VECTOR(P, 4) input = inputs[j];
output.x += input.x * weights[weithTo + 0 * kernelHXW + j];
output.y += input.y * weights[weithTo + 1 * kernelHXW + j];
output.z += input.z * weights[weithTo + 2 * kernelHXW + j];
output.w += input.w * weights[weithTo + 3 * kernelHXW + j];
}
#ifdef PRELU_CHANNEL
VECTOR(P, 4) alpha_value = alpha[gid.z];
output.x = output.x > 0 ? output.x : (alpha_value.x * output.x);
output.y = output.y > 0 ? output.y : (alpha_value.y * output.y);
output.z = output.z > 0 ? output.z : (alpha_value.z * output.z);
output.w = output.w > 0 ? output.w : (alpha_value.w * output.w);
#endif
#ifdef PRELU_ELEMENT
int alpha_to = (gid.y * outTexture.get_width() + gid.x) * outTexture.get_array_size();
VECTOR(P, 4) alpha_value = alpha[alpha_to + gid.z];
output.x = output.x > 0 ? output.x : (alpha_value.x * output.x);
output.y = output.y > 0 ? output.y : (alpha_value.y * output.y);
output.z = output.z > 0 ? output.z : (alpha_value.z * output.z);
output.w = output.w > 0 ? output.w : (alpha_value.w * output.w);
#endif
#ifdef PRELU_OTHER
P alpha_value = alpha[0];
output.x = output.x > 0 ? output.x : (alpha_value * output.x);
output.y = output.y > 0 ? output.y : (alpha_value * output.y);
output.z = output.z > 0 ? output.z : (alpha_value * output.z);
output.w = output.w > 0 ? output.w : (alpha_value * output.w);
#endif
outTexture.write(VECTOR(P, 4)(output), gid.xy, gid.z);
}
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
#define P float
#define PRELU_CHANNEL prelu_channel
#define PRELU_TYPE prelu_channel
#include "ConvAddPrelu.inc.metal"
#undef PRELU_TYPE
#undef PRELU_CHANNEL
#define PRELU_ELEMENT prelu_element
#define PRELU_TYPE prelu_element
#include "ConvAddPrelu.inc.metal"
#undef PRELU_TYPE
#undef PRELU_ELEMENT
#define PRELU_OTHER prelu_other
#define PRELU_TYPE prelu_other
#include "ConvAddPrelu.inc.metal"
#undef PRELU_TYPE
#undef PRELU_OTHER
#undef P
#define P half
#define PRELU_CHANNEL prelu_channel
#define PRELU_TYPE prelu_channel
#include "ConvAddPrelu.inc.metal"
#undef PRELU_TYPE
#undef PRELU_CHANNEL
#define PRELU_ELEMENT prelu_element
#define PRELU_TYPE prelu_element
#include "ConvAddPrelu.inc.metal"
#undef PRELU_TYPE
#undef PRELU_ELEMENT
#define PRELU_OTHER prelu_other
#define PRELU_TYPE prelu_other
#include "ConvAddPrelu.inc.metal"
#undef PRELU_TYPE
#undef PRELU_OTHER
#undef P
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
#pragma mark - conv bn relu
kernel void conv_batch_norm_relu_1x1(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
const device float4 *new_scale [[buffer(2)]],
const device float4 *new_biase [[buffer(3)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 1;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
float4 input;
for (uint i = 0; i < input_arr_size; ++i) {
input = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + i];
output.x += dot(input, weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + i];
output.y += dot(input, weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + i];
output.z += dot(input, weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + i];
output.w += dot(input, weight_w);
}
output = fmax(output * new_scale[gid.z] + new_biase[gid.z], 0.0);
outTexture.write(output, gid.xy, gid.z);
}
kernel void conv_batch_norm_relu_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
const device float4 *new_scale [[buffer(2)]],
const device float4 *new_biase [[buffer(3)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
float4 input[9];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i);
for (int j = 0; j < 9; ++j) {
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(input[j], weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(input[j], weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(input[j], weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(input[j], weight_w);
}
}
output = fmax(output * new_scale[gid.z] + new_biase[gid.z], 0.0);
outTexture.write(output, gid.xy, gid.z);
}
kernel void depthwise_conv_batch_norm_relu_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float *weights [[buffer(1)]],
const device float4 *new_scale [[buffer(2)]],
const device float4 *new_biase [[buffer(3)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
uint output_slice = gid.z;
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint weithTo = gid.z * kernelHXW * 4;
float4 output = float4(0.0);
float4 inputs[9];
inputs[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), output_slice);
inputs[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), output_slice);
inputs[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), output_slice);
inputs[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), output_slice);
inputs[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), output_slice);
inputs[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), output_slice);
inputs[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), output_slice);
inputs[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), output_slice);
inputs[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), output_slice);
for (int j = 0; j < 9; ++j) {
float4 input = inputs[j];
output.x += input.x * weights[weithTo + 0 * kernelHXW + j];
output.y += input.y * weights[weithTo + 1 * kernelHXW + j];
output.z += input.z * weights[weithTo + 2 * kernelHXW + j];
output.w += input.w * weights[weithTo + 3 * kernelHXW + j];
}
output = fmax(output * new_scale[gid.z] + new_biase[gid.z], 0.0);
outTexture.write(output, gid.xy, gid.z);
}
#pragma mark - half
kernel void conv_batch_norm_relu_1x1_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half4 *weights [[buffer(1)]],
const device half4 *new_scale [[buffer(2)]],
const device half4 *new_biase [[buffer(3)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 1;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
half4 input;
for (uint i = 0; i < input_arr_size; ++i) {
input = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
half4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + i];
output.x += dot(float4(input), float4(weight_x));
half4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + i];
output.y += dot(float4(input), float4(weight_y));
half4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + i];
output.z += dot(float4(input), float4(weight_z));
half4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + i];
output.w += dot(float4(input), float4(weight_w));
}
output = fmax(output * float4(new_scale[gid.z]) + float4(new_biase[gid.z]), 0.0);
outTexture.write(half4(output), gid.xy, gid.z);
}
kernel void conv_batch_norm_relu_3x3_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half4 *weights [[buffer(1)]],
const device half4 *new_scale [[buffer(2)]],
const device half4 *new_biase [[buffer(3)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
half4 input[9];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i);
for (int j = 0; j < 9; ++j) {
half4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(float4(input[j]), float4(weight_x));
half4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(float4(input[j]), float4(weight_y));
half4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(float4(input[j]), float4(weight_z));
half4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(float4(input[j]), float4(weight_w));
}
}
output = fmax(output * float4(new_scale[gid.z]) + float4(new_biase[gid.z]), 0.0);
outTexture.write(half4(output), gid.xy, gid.z);
}
kernel void depthwise_conv_batch_norm_relu_3x3_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half *weights [[buffer(1)]],
const device half4 *new_scale [[buffer(2)]],
const device half4 *new_biase [[buffer(3)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
uint output_slice = gid.z;
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint weithTo = gid.z * kernelHXW * 4;
float4 output = float4(0.0);
half4 inputs[9];
inputs[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), output_slice);
inputs[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), output_slice);
inputs[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), output_slice);
inputs[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), output_slice);
inputs[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), output_slice);
inputs[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), output_slice);
inputs[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), output_slice);
inputs[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), output_slice);
inputs[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), output_slice);
for (int j = 0; j < 9; ++j) {
half4 input = inputs[j];
output.x += input.x * weights[weithTo + 0 * kernelHXW + j];
output.y += input.y * weights[weithTo + 1 * kernelHXW + j];
output.z += input.z * weights[weithTo + 2 * kernelHXW + j];
output.w += input.w * weights[weithTo + 3 * kernelHXW + j];
}
output = fmax(output * float4(new_scale[gid.z]) + float4(new_biase[gid.z]), 0.0);
outTexture.write(half4(output), gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
// conv
#pragma mark -- conv
kernel void conv_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
float4 input[9];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i);
for (int j = 0; j < 9; ++j) {
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(input[j], weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(input[j], weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(input[j], weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(input[j], weight_w);
}
}
outTexture.write(output, gid.xy, gid.z);
}
kernel void depthwise_conv_3x3(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float *weights [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
uint output_slice = gid.z;
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint weithTo = gid.z * kernelHXW * 4;
float4 output = float4(0.0);
float4 inputs[9];
inputs[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), output_slice);
inputs[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), output_slice);
inputs[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), output_slice);
inputs[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), output_slice);
inputs[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), output_slice);
inputs[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), output_slice);
inputs[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), output_slice);
inputs[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), output_slice);
inputs[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), output_slice);
for (int j = 0; j < 9; ++j) {
float4 input = inputs[j];
output.x += input.x * weights[weithTo + 0 * kernelHXW + j];
output.y += input.y * weights[weithTo + 1 * kernelHXW + j];
output.z += input.z * weights[weithTo + 2 * kernelHXW + j];
output.w += input.w * weights[weithTo + 3 * kernelHXW + j];
}
outTexture.write(output, gid.xy, gid.z);
}
kernel void conv_1x1(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 1;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
float4 input;
for (uint i = 0; i < input_arr_size; ++i) {
input = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
float4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + i];
output.x += dot(input, weight_x);
float4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + i];
output.y += dot(input, weight_y);
float4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + i];
output.z += dot(input, weight_z);
float4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + i];
output.w += dot(input, weight_w);
}
outTexture.write(output, gid.xy, gid.z);
}
kernel void conv_3x3_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half4 *weights [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
const ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
half4 input[9];
for (uint i = 0; i < input_arr_size; ++i) {
input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i);
for (int j = 0; j < 9; ++j) {
half4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.x += dot(float4(input[j]), float4(weight_x));
half4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.y += dot(float4(input[j]), float4(weight_y));
half4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.z += dot(float4(input[j]), float4(weight_z));
half4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + j * input_arr_size + i];
output.w += dot(float4(input[j]), float4(weight_w));
}
}
outTexture.write(half4(output), gid.xy, gid.z);
}
kernel void depthwise_conv_3x3_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half *weights [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
uint output_slice = gid.z;
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 9;
uint weithTo = gid.z * kernelHXW * 4;
float4 output = float4(0.0);
half4 inputs[9];
inputs[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), output_slice);
inputs[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), output_slice);
inputs[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), output_slice);
inputs[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), output_slice);
inputs[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), output_slice);
inputs[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), output_slice);
inputs[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), output_slice);
inputs[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), output_slice);
inputs[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), output_slice);
for (int j = 0; j < 9; ++j) {
half4 input = inputs[j];
output.x += float(input.x) * float(weights[weithTo + 0 * kernelHXW + j]);
output.y += float(input.y) * float(weights[weithTo + 1 * kernelHXW + j]);
output.z += float(input.z) * float(weights[weithTo + 2 * kernelHXW + j]);
output.w += float(input.w) * float(weights[weithTo + 3 * kernelHXW + j]);
}
outTexture.write(half4(output), gid.xy, gid.z);
}
kernel void conv_1x1_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half4 *weights [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
ushort2 stride = ushort2(param.strideX, param.strideY);
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint kernelHXW = 1;
uint input_arr_size = inTexture.get_array_size();
uint weithTo = gid.z * kernelHXW * input_arr_size * 4;
float4 output = float4(0.0);
half4 input;
for (uint i = 0; i < input_arr_size; ++i) {
input = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
half4 weight_x = weights[weithTo + 0 * kernelHXW * input_arr_size + i];
output.x += dot(float4(input), float4(weight_x));
half4 weight_y = weights[weithTo + 1 * kernelHXW * input_arr_size + i];
output.y += dot(float4(input), float4(weight_y));
half4 weight_z = weights[weithTo + 2 * kernelHXW * input_arr_size + i];
output.z += dot(float4(input), float4(weight_z));
half4 weight_w = weights[weithTo + 3 * kernelHXW * input_arr_size + i];
output.w += dot(float4(input), float4(weight_w));
}
outTexture.write(half4(output), gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
struct MetalConvTransposeParam{
ushort kernelW;
ushort kernelH;
ushort strideX;
ushort strideY;
ushort paddingX;
ushort paddingY;
ushort dilationX;
ushort dilationY;
};
kernel void conv_transpose2x2_stride2(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant MetalConvTransposeParam &param [[buffer(0)]],
const device float4 *weights [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
int input_array_size = inTexture.get_array_size();
int kernel_index_x = gid.x % 2;
int kernel_index_y = gid.y % 2;
int kernel_index = kernel_index_y * 2 + kernel_index_x;
int kernel_to = gid.z * input_array_size * 4 * 4 + (kernel_index * input_array_size);
int input_x = gid.x / 2;
int input_y = gid.y / 2;
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 output = float4(0.0);
for (int i = 0; i < input_array_size; ++i) {
float4 input = inTexture.sample(sample, float2(input_x, input_y), i);
float4 kernel_slice0 = weights[kernel_to + input_array_size * 4 * 0 + i];
float4 kernel_slice1 = weights[kernel_to + input_array_size * 4 * 1 + i];
float4 kernel_slice2 = weights[kernel_to + input_array_size * 4 * 2 + i];
float4 kernel_slice3 = weights[kernel_to + input_array_size * 4 * 3 + i];
output.x += dot(input, kernel_slice0);
output.y += dot(input, kernel_slice1);
output.z += dot(input, kernel_slice2);
output.w += dot(input, kernel_slice3);
}
outTexture.write(output, gid.xy, gid.z);
}
kernel void conv_transpose2x2_stride2_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvTransposeParam &param [[buffer(0)]],
const device half4 *weights [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
int input_array_size = inTexture.get_array_size();
int kernel_index_x = gid.x % 2;
int kernel_index_y = gid.y % 2;
int kernel_index = kernel_index_y * 2 + kernel_index_x;
int kernel_to = gid.z * input_array_size * 4 * 4 + (kernel_index * input_array_size);
int input_x = gid.x / 2;
int input_y = gid.y / 2;
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 output = float4(0.0);
for (int i = 0; i < input_array_size; ++i) {
half4 input = inTexture.sample(sample, float2(input_x, input_y), i);
half4 kernel_slice0 = weights[kernel_to + input_array_size * 4 * 0 + i];
half4 kernel_slice1 = weights[kernel_to + input_array_size * 4 * 1 + i];
half4 kernel_slice2 = weights[kernel_to + input_array_size * 4 * 2 + i];
half4 kernel_slice3 = weights[kernel_to + input_array_size * 4 * 3 + i];
output.x += dot(float4(input), float4(kernel_slice0));
output.y += dot(float4(input), float4(kernel_slice1));
output.z += dot(float4(input), float4(kernel_slice2));
output.w += dot(float4(input), float4(kernel_slice3));
}
outTexture.write(half4(output), gid.xy, gid.z);
}
//kernel void conv_transpose(texture2d_array<float, access::sample> inTexture [[texture(0)]],
// texture2d_array<float, access::write> outTexture [[texture(1)]],
// constant MetalConvTransposeParam &param [[buffer(0)]],
// const device float4 *weights [[buffer(1)]],
// uint3 gid [[thread_position_in_grid]]){
// if (gid.x >= outTexture.get_width() ||
// gid.y >= outTexture.get_height() ||
// gid.z >= outTexture.get_array_size()) {
// return;
// }
//
// int input_array_size = inTexture.get_array_size();
//
// uint kernel_one_output_slice = input_array_size * param.kernelW * param.kernelH;
//
// uint kernel_stride_z = gid.z * 4 * (kernel_one_output_slice);
//
// constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
//
// float4 output;
//
// for (int w = 0; w < param.kernelW; ++w) {
// int top = gid.x - w * param.dilationX + param.paddingX;
// int input_x = top / param.strideX;
// if (top < 0 || input_x >= int(inTexture.get_width())) {
// continue;
// }
//
// for (int h = 0; h < param.kernelH; ++h) {
// int top_y = gid.y - h * param.dilationY + param.paddingY;
// int input_y = top_y / param.strideY;
// if (top_y < 0 || input_y >= int(inTexture.get_height())) {
// continue;
// }
//
// uint kernel_index = (w * param.kernelH + h) * inTexture.get_array_size();
//
// for (int slice = 0; slice < input_array_size; ++slice) {
//
// float4 input;
// float4 kernel_slice = weights[kernel_stride_z + 0 * kernel_one_output_slice + kernel_index + slice];
// float4 kernel_slice1 = weights[kernel_stride_z + 1 * kernel_one_output_slice + kernel_index + slice];
//
// float4 kernel_slice2 = weights[kernel_stride_z + 2 * kernel_one_output_slice + kernel_index + slice];
//
// float4 kernel_slice3 = weights[kernel_stride_z + 3 * kernel_one_output_slice + kernel_index + slice];
//
// input = inTexture.sample(sample, float2(input_x, input_y), slice);
// output.x += dot(input, kernel_slice);
// output.y += dot(input, kernel_slice1);
// output.z += dot(input, kernel_slice2);
// output.w += dot(input, kernel_slice3);
// }
// }
// }
//
// outTexture.write(output, gid.xy, gid.z);
//}
//
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
struct ElementwiseAddParam {
int32_t fast;
int32_t axis;
int32_t ylen;
int32_t xdim[4];
int32_t xtrans[4];
int32_t ydim[4];
int32_t ytrans[4];
};
kernel void elementwise_add(texture2d_array<float, access::read> inputX [[texture(0)]],
texture2d_array<float, access::read> inputY [[texture(1)]],
texture2d_array<float, access::write> outTexture [[texture(2)]],
constant ElementwiseAddParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
float4 rx, ry;
if (pm.fast == 1) {
rx = inputX.read(gid.xy, gid.z);
ry = inputY.read(gid.xy, gid.z);
} else {
rx = inputX.read(gid.xy, gid.z);
int32_t x_xyzn[4] = {int32_t(gid.x), int32_t(gid.y), int32_t(gid.z), 0}, x_abcd[4], t_abcd[4];
int32_t y_abcd[4] = {0, 0, 0, 0}, y_xyzn[4];
int32_t xtrans[4] = {pm.xtrans[0], pm.xtrans[1], pm.xtrans[2], pm.xtrans[3]};
int32_t ytrans[4] = {pm.ytrans[0], pm.ytrans[1], pm.ytrans[2], pm.ytrans[3]};
int32_t yshift = 4 - pm.ylen - pm.axis;
for (int n = 0; n < 4; n++) {
x_xyzn[3] = n;
xyzn2abcd(pm.xdim[3], x_xyzn, x_abcd);
invtrans(xtrans, x_abcd, t_abcd);
for (int k = pm.axis; k < (pm.axis + pm.ylen); k++) {
y_abcd[yshift+k] = t_abcd[k];
}
trans(ytrans, y_abcd, t_abcd);
abcd2xyzn(pm.ydim[3], t_abcd, y_xyzn);
ry[n] = inputY.read(uint2(y_xyzn[0], y_xyzn[1]), y_xyzn[2])[y_xyzn[3]];
}
}
float4 r = rx + ry;
outTexture.write(r, gid.xy, gid.z);
}
kernel void elementwise_add_half(texture2d_array<half, access::read> inputX [[texture(0)]],
texture2d_array<half, access::read> inputY [[texture(1)]],
texture2d_array<half, access::write> outTexture [[texture(2)]],
constant ElementwiseAddParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
half4 rx, ry;
if (pm.fast == 1) {
rx = inputX.read(gid.xy, gid.z);
ry = inputY.read(gid.xy, gid.z);
} else {
rx = inputX.read(gid.xy, gid.z);
int32_t x_xyzn[4] = {int32_t(gid.x), int32_t(gid.y), int32_t(gid.z), 0}, x_abcd[4], t_abcd[4];
int32_t y_abcd[4] = {0, 0, 0, 0}, y_xyzn[4];
int32_t xtrans[4] = {pm.xtrans[0], pm.xtrans[1], pm.xtrans[2], pm.xtrans[3]};
int32_t ytrans[4] = {pm.ytrans[0], pm.ytrans[1], pm.ytrans[2], pm.ytrans[3]};
int32_t yshift = 4 - pm.ylen - pm.axis;
for (int n = 0; n < 4; n++) {
x_xyzn[3] = n;
xyzn2abcd(pm.xdim[3], x_xyzn, x_abcd);
invtrans(xtrans, x_abcd, t_abcd);
for (int k = pm.axis; k < (pm.axis + pm.ylen); k++) {
y_abcd[yshift+k] = t_abcd[k];
}
trans(ytrans, y_abcd, t_abcd);
abcd2xyzn(pm.ydim[3], t_abcd, y_xyzn);
ry[n] = inputY.read(uint2(y_xyzn[0], y_xyzn[1]), y_xyzn[2])[y_xyzn[3]];
}
}
half4 r = rx + ry;
outTexture.write(r, gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef P
#include <metal_stdlib>
#include "Macro.metal"
using namespace metal;
kernel void FUNC3_(elementwise_add, PRELU_TYPE, P)(texture2d_array<P, access::read> inputX [[texture(0)]],
texture2d_array<P, access::read> inputY [[texture(1)]],
texture2d_array<P, access::write> outTexture [[texture(2)]],
constant ElementwiseAddParam &pm [[buffer(0)]],
#ifdef PRELU_CHANNEL
const device VECTOR(P, 4) *alpha [[buffer(1)]],
#endif
#ifdef PRELU_ELEMENT
const device VECTOR(P, 4) *alpha [[buffer(1)]],
#endif
#ifdef PRELU_OTHER
const device P *alpha [[buffer(1)]],
#endif
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
VECTOR(P, 4) rx, ry;
if (pm.fast == 1) {
rx = inputX.read(gid.xy, gid.z);
ry = inputY.read(gid.xy, gid.z);
} else {
rx = inputX.read(gid.xy, gid.z);
int32_t x_xyzn[4] = {int32_t(gid.x), int32_t(gid.y), int32_t(gid.z), 0}, x_abcd[4], t_abcd[4];
int32_t y_abcd[4] = {0, 0, 0, 0}, y_xyzn[4];
int32_t xtrans[4] = {pm.xtrans[0], pm.xtrans[1], pm.xtrans[2], pm.xtrans[3]};
int32_t ytrans[4] = {pm.ytrans[0], pm.ytrans[1], pm.ytrans[2], pm.ytrans[3]};
int32_t yshift = 4 - pm.ylen - pm.axis;
for (int n = 0; n < 4; n++) {
x_xyzn[3] = n;
xyzn2abcd(pm.xdim[3], x_xyzn, x_abcd);
invtrans(xtrans, x_abcd, t_abcd);
for (int k = pm.axis; k < (pm.axis + pm.ylen); k++) {
y_abcd[yshift+k] = t_abcd[k];
}
trans(ytrans, y_abcd, t_abcd);
abcd2xyzn(pm.ydim[3], t_abcd, y_xyzn);
ry[n] = inputY.read(uint2(y_xyzn[0], y_xyzn[1]), y_xyzn[2])[y_xyzn[3]];
}
}
VECTOR(P, 4) output = rx + ry;
#ifdef PRELU_CHANNEL
VECTOR(P, 4) alpha_value = alpha[gid.z];
output.x = output.x > 0 ? output.x : (alpha_value.x * output.x);
output.y = output.y > 0 ? output.y : (alpha_value.y * output.y);
output.z = output.z > 0 ? output.z : (alpha_value.z * output.z);
output.w = output.w > 0 ? output.w : (alpha_value.w * output.w);
#endif
#ifdef PRELU_ELEMENT
int alpha_to = (gid.y * outTexture.get_width() + gid.x) * outTexture.get_array_size();
VECTOR(P, 4) alpha_value = alpha[alpha_to + gid.z];
output.x = output.x > 0 ? output.x : (alpha_value.x * output.x);
output.y = output.y > 0 ? output.y : (alpha_value.y * output.y);
output.z = output.z > 0 ? output.z : (alpha_value.z * output.z);
output.w = output.w > 0 ? output.w : (alpha_value.w * output.w);
#endif
#ifdef PRELU_OTHER
P alpha_value = alpha[0];
output.x = output.x > 0 ? output.x : (alpha_value * output.x);
output.y = output.y > 0 ? output.y : (alpha_value * output.y);
output.z = output.z > 0 ? output.z : (alpha_value * output.z);
output.w = output.w > 0 ? output.w : (alpha_value * output.w);
#endif
outTexture.write(output, gid.xy, gid.z);
}
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
struct ElementwiseAddParam {
int32_t fast;
int32_t axis;
int32_t ylen;
int32_t xdim[4];
int32_t xtrans[4];
int32_t ydim[4];
int32_t ytrans[4];
};
#define P float
#define PRELU_CHANNEL prelu_channel
#define PRELU_TYPE channel
#include "ElementwiseAddPreluKernel.inc.metal"
#undef PRELU_TYPE
#undef PRELU_CHANNEL
#define PRELU_ELEMENT element
#define PRELU_TYPE prelu_element
#include "ElementwiseAddPreluKernel.inc.metal"
#undef PRELU_TYPE
#undef PRELU_ELEMENT
#define PRELU_OTHER other
#define PRELU_TYPE prelu_other
#include "ElementwiseAddPreluKernel.inc.metal"
#undef PRELU_TYPE
#undef PRELU_OTHER
#undef P
#define P half
#define PRELU_CHANNEL channel
#define PRELU_TYPE channel
#include "ElementwiseAddPreluKernel.inc.metal"
#undef PRELU_TYPE
#undef PRELU_CHANNEL
#define PRELU_ELEMENT element
#define PRELU_TYPE prelu_element
#include "ElementwiseAddPreluKernel.inc.metal"
#undef PRELU_TYPE
#undef PRELU_ELEMENT
#define PRELU_OTHER other
#define PRELU_TYPE prelu_other
#include "ElementwiseAddPreluKernel.inc.metal"
#undef PRELU_TYPE
#undef PRELU_OTHER
#undef P
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef P
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT2(a, b) a ## b
#define FUNC(m, n, q) CONCAT3_(m, n, q)
#define FUNC_T(m, n) CONCAT2_(m, n)
#define VECTOR(p, n) CONCAT2(p, n)
kernel void FUNC_T(fetch, P)(texture2d_array<P, access::read> inTexture [[texture(0)]],
device float *output [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= inTexture.get_width() ||
gid.y >= inTexture.get_height() ||
gid.z >= inTexture.get_array_size()) {
return;
}
int input_width = inTexture.get_width();
int input_height = inTexture.get_height();
const VECTOR(P, 4) input = inTexture.read(gid.xy, gid.z);
int output_to = 4 * input_width * input_height;
output[gid.z * output_to + 0 * input_width * input_height + gid.y * input_width + gid.x] = input.x;
output[gid.z * output_to + 1 * input_width * input_height + gid.y * input_width + gid.x] = input.y;
output[gid.z * output_to + 2 * input_width * input_height + gid.y * input_width + gid.x] = input.z;
output[gid.z * output_to + 3 * input_width * input_height + gid.y * input_width + gid.x] = input.w;
}
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
#define P float
#include "FetchKernel.inc.metal"
#undef P
#define P half
#include "FetchKernel.inc.metal"
#undef P
kernel void fetch_placeholder(texture2d_array<float, access::read> inTexture [[texture(0)]],
device float *output [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
}
kernel void fetch_placeholder_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
device float *output [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
// 占位函数, 啥也没干
kernel void place_holder(texture2d<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
}
struct OutputDim {
ushort width;
ushort height;
ushort strideX;
ushort strideY;
};
kernel void resize(texture2d<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant OutputDim &params [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint2 pos = gid.xy * uint2(params.strideX, params.strideY);
const half4 input = inTexture.read(pos);
outTexture.write(half4(input.x, input.y, input.z, input.w), gid.xy, gid.z);
}
kernel void texture2d_to_2d_array(texture2d<float, access::read> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= inTexture.get_width() ||
gid.y >= inTexture.get_height()){
return;
}
const float4 input = inTexture.read(gid.xy);
outTexture.write(input, gid.xy, 0);
}
kernel void texture2d_to_2d_array_half(texture2d<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= inTexture.get_width() ||
gid.y >= inTexture.get_height()){
return;
}
const half4 input = inTexture.read(gid.xy);
outTexture.write(input, gid.xy, 0);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d
#define CONCAT5_(a, b, c, d, e) a ## _ ## b ## _ ## c ## _ ## d ## _ ## e
#define FUNC(f, r, n, v, p) CONCAT5_(f, r, n, v, p)
#define VECTOR(p, n) CONCAT2(p, n)
#define FUNC2_(a, b) CONCAT2_(a, b)
#define FUNC3_(a, b, c) CONCAT3_(a, b, c)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
kernel void nms_fetch_result(texture2d_array<float, access::read> inTexture [[texture(0)]],
device float *output [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= inTexture.get_width() ||
gid.y >= inTexture.get_height() ||
gid.z >= inTexture.get_array_size()) {
return;
}
int input_width = inTexture.get_width();
const float4 input = inTexture.read(gid.xy, gid.z);
output[gid.y * input_width + gid.x] = input.x;
}
kernel void nms_fetch_result_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
device float *output [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= inTexture.get_width() ||
gid.y >= inTexture.get_height() ||
gid.z >= inTexture.get_array_size()) {
return;
}
int input_width = inTexture.get_width();
const half4 input = inTexture.read(gid.xy, gid.z);
output[gid.y * input_width + gid.x] = input.x;
}
kernel void nms_fetch_bbox(texture2d_array<float, access::read> inTexture [[texture(0)]],
device float4 *output [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= inTexture.get_width() ||
gid.y >= inTexture.get_height() ||
gid.z >= inTexture.get_array_size()) {
return;
}
int input_width = inTexture.get_width();
// int input_height = inTexture.get_height();
const float4 input = inTexture.read(gid.xy, gid.z);
output[gid.y * input_width + gid.x] = input;
}
kernel void nms_fetch_bbox_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
device float4 *output [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= inTexture.get_width() ||
gid.y >= inTexture.get_height() ||
gid.z >= inTexture.get_array_size()) {
return;
}
int input_width = inTexture.get_width();
// int input_height = inTexture.get_height();
const half4 input = inTexture.read(gid.xy, gid.z);
output[gid.y * input_width + gid.x] = float4(input);
}
//
// PoolKernel.inc.metal
// paddle-mobile
//
// Created by liuRuiLong on 2018/12/29.
// Copyright © 2018 orange. All rights reserved.
//
#ifdef P
kernel void FUNC2_(pool, P)(texture2d_array<P, access::read> inTexture [[texture(0)]],
texture2d_array<P, access::write> outTexture [[texture(1)]],
constant PoolParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
int xmin = gid.x * pm.strideX - pm.paddingX;
int xmax = min(xmin + pm.ksizeX, int(inTexture.get_width()));
xmin = max(xmin, 0);
int ymin = gid.y * pm.strideX - pm.paddingX;
int ymax = min(ymin + pm.ksizeX, int(inTexture.get_height()));
ymin = max(ymin, 0);
VECTOR(P, 4) r = 0;
if (pm.poolType == 0) {
r = inTexture.read(uint2(xmin, ymin), gid.z);
for (int x = xmin; x < xmax; x++) {
for (int y = ymin; y < ymax; y++) {
r = fmax(r, inTexture.read(uint2(x, y), gid.z));
}
}
} else if (pm.poolType == 1) {
for (int x = xmin; x < xmax; x++) {
for (int y = ymin; y < ymax; y++) {
r += inTexture.read(uint2(x, y), gid.z);
}
}
r /= (xmax - xmin) * (ymax - ymin);
}
outTexture.write(r, gid.xy, gid.z);
}
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Macro.metal"
using namespace metal;
struct PoolParam {
int ksizeX;
int ksizeY;
int strideX;
int strideY;
int paddingX;
int paddingY;
int poolType;
};
#define P float
#import "PoolKernel.inc.metal"
#undef P
#define P half
#import "PoolKernel.inc.metal"
#undef P
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
kernel void prelu_channel(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
const device float4 *alpha [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]){
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 input = inTexture.sample(sample, float2(gid.x, gid.y), gid.z);
float4 alpha_value = alpha[gid.z];
float4 output;
output.x = input.x > 0 ? input.x : (alpha_value.x * input.x);
output.y = input.y > 0 ? input.y : (alpha_value.y * input.y);
output.z = input.z > 0 ? input.z : (alpha_value.z * input.z);
output.w = input.w > 0 ? input.w : (alpha_value.w * input.w);
outTexture.write(output, gid.xy, gid.z);
}
kernel void prelu_element(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
const device float4 *alpha [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]){
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 input = inTexture.sample(sample, float2(gid.x, gid.y), gid.z);
int alpha_to = (gid.y * inTexture.get_width() + gid.x) * inTexture.get_array_size();
float4 alpha_value = alpha[alpha_to + gid.z];
float4 output;
output.x = input.x > 0 ? input.x : (alpha_value.x * input.x);
output.y = input.y > 0 ? input.y : (alpha_value.y * input.y);
output.z = input.z > 0 ? input.z : (alpha_value.z * input.z);
output.w = input.w > 0 ? input.w : (alpha_value.w * input.w);
outTexture.write(output, gid.xy, gid.z);
}
kernel void prelu_other(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
const device float *alpha [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]){
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 input = inTexture.sample(sample, float2(gid.x, gid.y), gid.z);
float alpha_value = alpha[0];
float4 output;
output.x = input.x > 0 ? input.x : (alpha_value * input.x);
output.y = input.y > 0 ? input.y : (alpha_value * input.y);
output.z = input.z > 0 ? input.z : (alpha_value * input.z);
output.w = input.w > 0 ? input.w : (alpha_value * input.w);
outTexture.write(output, gid.xy, gid.z);
}
kernel void prelu_channel_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 *alpha [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]){
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
half4 input = inTexture.sample(sample, float2(gid.x, gid.y), gid.z);
half4 alpha_value = alpha[gid.z];
half4 output;
output.x = input.x > 0 ? input.x : (alpha_value.x * input.x);
output.y = input.y > 0 ? input.y : (alpha_value.y * input.y);
output.z = input.z > 0 ? input.z : (alpha_value.z * input.z);
output.w = input.w > 0 ? input.w : (alpha_value.w * input.w);
outTexture.write(output, gid.xy, gid.z);
}
kernel void prelu_element_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 *alpha [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]){
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
half4 input = inTexture.sample(sample, float2(gid.x, gid.y), gid.z);
int alpha_to = (gid.y * inTexture.get_width() + gid.x) * inTexture.get_array_size();
half4 alpha_value = alpha[alpha_to + gid.z];
half4 output;
output.x = input.x > 0 ? input.x : (alpha_value.x * input.x);
output.y = input.y > 0 ? input.y : (alpha_value.y * input.y);
output.z = input.z > 0 ? input.z : (alpha_value.z * input.z);
output.w = input.w > 0 ? input.w : (alpha_value.w * input.w);
outTexture.write(output, gid.xy, gid.z);
}
kernel void prelu_other_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half *alpha [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]){
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
half4 input = inTexture.sample(sample, float2(gid.x, gid.y), gid.z);
half alpha_value = alpha[0];
half4 output;
output.x = input.x > 0 ? input.x : (alpha_value * input.x);
output.y = input.y > 0 ? input.y : (alpha_value * input.y);
output.z = input.z > 0 ? input.z : (alpha_value * input.z);
output.w = input.w > 0 ? input.w : (alpha_value * input.w);
outTexture.write(output, gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
struct PriorBoxMetalParam {
float offset;
float stepWidth;
float stepHeight;
float minSize;
float maxSize;
float imageWidth;
float imageHeight;
bool clip;
uint numPriors;
uint aspecRatiosSize;
uint minSizeSize;
uint maxSizeSize;
};
kernel void prior_box(texture2d_array<float, access::read> inTexture [[texture(0)]],
texture2d_array<float, access::write> outBoxTexture [[texture(1)]],
texture2d_array<float, access::write> varianceTexture [[texture(2)]],
const device float *aspect_ratios [[buffer(0)]],
constant PriorBoxMetalParam &param [[buffer(1)]],
const device float4 *variances [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outBoxTexture.get_width() ||
gid.y >= outBoxTexture.get_height() ||
gid.z >= outBoxTexture.get_array_size()) return;
float center_x = (gid.x + param.offset) * param.stepWidth;
float center_y = (gid.y + param.offset) * param.stepHeight;
float box_width, box_height;
if (gid.z < param.aspecRatiosSize) {
float ar = aspect_ratios[gid.z];
box_width = param.minSize * sqrt(ar) / 2;
box_height = param.minSize / sqrt(ar) / 2;
float4 box;
box.x = (center_x - box_width) / param.imageWidth;
box.y = (center_y - box_height) / param.imageHeight;
box.z = (center_x + box_width) / param.imageWidth;
box.w = (center_y + box_height) / param.imageHeight;
float4 res;
if (param.clip) {
res = fmin(fmax(box, 0.0), 1.0);
} else {
res = box;
}
outBoxTexture.write(res, gid.xy, gid.z);
} else if (gid.z >= param.aspecRatiosSize) {
if (param.maxSizeSize > 0) {
box_width = box_height = sqrt(param.minSize * param.maxSize) / 2;
float4 max_box;
max_box.x = (center_x - box_width) / param.imageWidth;
max_box.y = (center_y - box_height) / param.imageHeight;
max_box.z = (center_x + box_width) / param.imageWidth;
max_box.w = (center_y + box_height) / param.imageHeight;
float4 res;
if (param.clip) {
res = min(max(max_box, 0.0), 1.0);
} else {
res = max_box;
}
outBoxTexture.write(max_box, gid.xy, gid.z);
}
}
float4 variance = variances[0];
if (gid.z < param.numPriors) {
float4 variances_output;
variances_output.x = variance.x;
variances_output.y = variance.y;
variances_output.z = variance.z;
variances_output.w = variance.w;
varianceTexture.write(variances_output, gid.xy, gid.z);
}
}
kernel void prior_box_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outBoxTexture [[texture(1)]],
texture2d_array<half, access::write> varianceTexture [[texture(2)]],
const device half *aspect_ratios [[buffer(0)]],
constant PriorBoxMetalParam &param [[buffer(1)]],
const device float4 *variances [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outBoxTexture.get_width() ||
gid.y >= outBoxTexture.get_height() ||
gid.z >= outBoxTexture.get_array_size()) return;
float center_x = (gid.x + param.offset) * param.stepWidth;
float center_y = (gid.y + param.offset) * param.stepHeight;
float box_width, box_height;
if (gid.z < param.aspecRatiosSize) {
half ar = aspect_ratios[gid.z];
box_width = param.minSize * sqrt(ar) / 2;
box_height = param.minSize / sqrt(ar) / 2;
float4 box;
box.x = (center_x - box_width) / param.imageWidth;
box.y = (center_y - box_height) / param.imageHeight;
box.z = (center_x + box_width) / param.imageWidth;
box.w = (center_y + box_height) / param.imageHeight;
float4 res;
if (param.clip) {
res = fmin(fmax(box, 0.0), 1.0);
} else {
res = box;
}
outBoxTexture.write(half4(res), gid.xy, gid.z);
} else if (gid.z >= param.aspecRatiosSize) {
if (param.maxSizeSize > 0) {
box_width = box_height = sqrt(param.minSize * param.maxSize) / 2;
float4 max_box;
max_box.x = (center_x - box_width) / param.imageWidth;
max_box.y = (center_y - box_height) / param.imageHeight;
max_box.z = (center_x + box_width) / param.imageWidth;
max_box.w = (center_y + box_height) / param.imageHeight;
float4 res;
if (param.clip) {
res = min(max(max_box, 0.0), 1.0);
} else {
res = max_box;
}
outBoxTexture.write(half4(max_box), gid.xy, gid.z);
}
}
float4 variance = variances[0];
if (gid.z < param.numPriors) {
float4 variances_output;
variances_output.x = variance.x;
variances_output.y = variance.y;
variances_output.z = variance.z;
variances_output.w = variance.w;
varianceTexture.write(half4(variances_output), gid.xy, gid.z);
}
}
kernel void prior_box_MinMaxAspectRatiosOrder(texture2d_array<float, access::read> inTexture [[texture(0)]],
texture2d_array<float, access::write> outBoxTexture [[texture(1)]],
texture2d_array<float, access::write> varianceTexture [[texture(2)]],
const device float *aspect_ratios [[buffer(0)]],
constant PriorBoxMetalParam &param [[buffer(1)]],
const device float4 *variances [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outBoxTexture.get_width() ||
gid.y >= outBoxTexture.get_height() ||
gid.z >= outBoxTexture.get_array_size()) return;
float center_x = (gid.x + param.offset) * param.stepWidth;
float center_y = (gid.y + param.offset) * param.stepHeight;
float box_width, box_height;
if (gid.z == 0) {
box_width = box_height = param.minSize / 2;
float4 box;
box.x = (center_x - box_width) / param.imageWidth;
box.y = (center_y - box_height) / param.imageHeight;
box.z = (center_x + box_width) / param.imageWidth;
box.w = (center_y + box_height) / param.imageHeight;
float4 res;
if (param.clip) {
res = fmin(fmax(box, 0.0), 1.0);
} else {
res = box;
}
outBoxTexture.write(res, gid.xy, gid.z);
}
if (gid.z == 1 && param.maxSizeSize > 0) {
box_width = box_height = sqrt(param.minSize * param.maxSize) / 2;
float4 max_box;
max_box.x = (center_x - box_width) / param.imageWidth;
max_box.y = (center_y - box_height) / param.imageHeight;
max_box.z = (center_x + box_width) / param.imageWidth;
max_box.w = (center_y + box_height) / param.imageHeight;
float4 res;
if (param.clip) {
res = min(max(max_box, 0.0), 1.0);
} else {
res = max_box;
}
outBoxTexture.write(res, gid.xy, gid.z);
}
int aspect_to = 0;
if (param.maxSizeSize > 0) {
aspect_to = gid.z - 2;
} else {
aspect_to = gid.z - 1;
}
if (aspect_to >= 0 && aspect_to < int(param.aspecRatiosSize)) {
int skip = 0;
for (int i = 0; i < aspect_to + 1; ++i) {
if (fabs(aspect_ratios[i] - 1.) < 1e-6) {
skip += 1;
}
}
aspect_to += skip;
float ar = aspect_ratios[aspect_to];
box_width = param.minSize * sqrt(ar) / 2;
box_height = param.minSize / sqrt(ar) / 2;
float4 box;
box.x = (center_x - box_width) / param.imageWidth;
box.y = (center_y - box_height) / param.imageHeight;
box.z = (center_x + box_width) / param.imageWidth;
box.w = (center_y + box_height) / param.imageHeight;
float4 res;
if (param.clip) {
res = fmin(fmax(box, 0.0), 1.0);
} else {
res = box;
}
outBoxTexture.write(res, gid.xy, gid.z);
}
float4 variance = variances[0];
if (gid.z < param.numPriors) {
float4 variances_output;
variances_output.x = variance.x;
variances_output.y = variance.y;
variances_output.z = variance.z;
variances_output.w = variance.w;
varianceTexture.write(variances_output, gid.xy, gid.z);
}
}
kernel void prior_box_MinMaxAspectRatiosOrder_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outBoxTexture [[texture(1)]],
texture2d_array<half, access::write> varianceTexture [[texture(2)]],
const device half *aspect_ratios [[buffer(0)]],
constant PriorBoxMetalParam &param [[buffer(1)]],
const device float4 *variances [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outBoxTexture.get_width() ||
gid.y >= outBoxTexture.get_height() ||
gid.z >= outBoxTexture.get_array_size()) return;
float center_x = (gid.x + param.offset) * param.stepWidth;
float center_y = (gid.y + param.offset) * param.stepHeight;
float box_width, box_height;
if (gid.z == 0) {
box_width = box_height = param.minSize / 2;
float4 box;
box.x = (center_x - box_width) / param.imageWidth;
box.y = (center_y - box_height) / param.imageHeight;
box.z = (center_x + box_width) / param.imageWidth;
box.w = (center_y + box_height) / param.imageHeight;
float4 res;
if (param.clip) {
res = fmin(fmax(box, 0.0), 1.0);
} else {
res = box;
}
outBoxTexture.write(half4(res), gid.xy, gid.z);
}
if (gid.z == 1 && param.maxSizeSize > 0) {
box_width = box_height = sqrt(param.minSize * param.maxSize) / 2;
float4 max_box;
max_box.x = (center_x - box_width) / param.imageWidth;
max_box.y = (center_y - box_height) / param.imageHeight;
max_box.z = (center_x + box_width) / param.imageWidth;
max_box.w = (center_y + box_height) / param.imageHeight;
float4 res;
if (param.clip) {
res = min(max(max_box, 0.0), 1.0);
} else {
res = max_box;
}
outBoxTexture.write(half4(res), gid.xy, gid.z);
}
int aspect_to = 0;
if (param.maxSizeSize > 0) {
aspect_to = gid.z - 2;
} else {
aspect_to = gid.z - 1;
}
if (aspect_to > 0 && aspect_to < int(param.aspecRatiosSize) && fabs(aspect_ratios[aspect_to] - 1.) > 1e-6) {
float ar = aspect_ratios[aspect_to];
box_width = param.minSize * sqrt(ar) / 2;
box_height = param.minSize / sqrt(ar) / 2;
float4 box;
box.x = (center_x - box_width) / param.imageWidth;
box.y = (center_y - box_height) / param.imageHeight;
box.z = (center_x + box_width) / param.imageWidth;
box.w = (center_y + box_height) / param.imageHeight;
float4 res;
if (param.clip) {
res = fmin(fmax(box, 0.0), 1.0);
} else {
res = box;
}
outBoxTexture.write(half4(res), gid.xy, gid.z);
}
float4 variance = variances[0];
if (gid.z < param.numPriors) {
float4 variances_output;
variances_output.x = variance.x;
variances_output.y = variance.y;
variances_output.z = variance.z;
variances_output.w = variance.w;
varianceTexture.write(half4(variances_output), gid.xy, gid.z);
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
kernel void relu_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
const float4 relu = fmax((float4)input, 0.0);
outTexture.write(half4(relu), gid.xy, gid.z);
}
kernel void relu(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const float4 input = inTexture.read(gid.xy, gid.z);
const float4 relu = fmax((float4)input, 0.0);
outTexture.write(float4(relu), gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef P
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d
#define FUNC(f, r1, r2, p) CONCAT4_(f, r1, r2, p)
#define VECTOR(p, n) CONCAT2(p, n)
#define FUNC_R(f, r) CONCAT2_(f, r)
kernel void FUNC(reshape, RIN, ROUT, P)(texture2d_array<P, access::read> inTexture [[texture(0)]],
texture2d_array<P, access::write> outTexture [[texture(1)]],
constant ReshapeParam &rp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
int oxyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, oabcd[4], ixyzn[4], iabcd[4];
ReshapeParam lrp = rp;
int oC = lrp.odim[lrp.otrans[3]];
int iC = lrp.idim[lrp.itrans[3]];
int count = lrp.odim[0] * lrp.odim[1] * lrp.odim[2] * lrp.odim[3];
VECTOR(P, 4) r;
for (int n = 0; n < 4; n++) {
oxyzn[3] = n;
#if ROUT == 4
xyzn2abcd_4(oC, oxyzn, oabcd);
#else
FUNC_R(xyzn2abcd, ROUT)(oxyzn, oabcd);
#endif
int tabcd[4];
invtrans(lrp.otrans, oabcd, tabcd);
int index = abcd2index(lrp.odim, tabcd);
if (index < count) {
index2abcd(lrp.idim, index, tabcd);
trans(lrp.itrans, tabcd, iabcd);
#if RIN == 4
abcd2xyzn_4(iC, iabcd, ixyzn);
#else
FUNC_R(abcd2xyzn, RIN)(iabcd, ixyzn);
#endif
r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]];
} else {
r[n] = 0;
}
}
outTexture.write(r, gid.xy, gid.z);
}
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONRITIONS OF ANY KINR, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
struct ReshapeParam {
int32_t idim[4];
int32_t itrans[4];
int32_t odim[4];
int32_t otrans[4];
};
#define P float
#define RIN 4
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#define RIN 3
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#define RIN 2
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#define RIN 1
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#undef P
#define P half
#define RIN 4
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#define RIN 3
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#define RIN 2
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#define RIN 1
#define ROUT 4
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 3
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 2
#include "ReshapeKernel.inc.metal"
#undef ROUT
#define ROUT 1
#include "ReshapeKernel.inc.metal"
#undef ROUT
#undef RIN
#undef P
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
struct resize_bilinear_param {
// int32_t out_h;
// int32_t out_w;
float ratio_h;
float ratio_w;
};
kernel void resize_bilinear(texture2d_array<float, access::read> input [[texture(0)]],
texture2d_array<float, access::write> output [[texture(2)]],
constant resize_bilinear_param & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
float4 r;
if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) {
r = input.read(gid.xy, gid.z);
} else {
float w = gid.x * pm.ratio_w;
float h = gid.y * pm.ratio_h;
uint w0 = w, h0 = h;
uint w1 = w0 + 1, h1 = h0 + 1;
float w1lambda = w - w0, h1lambda = h - h0;
float w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda;
if (w1 >= input.get_width()) w1 = w0;
if (h1 >= input.get_height()) h1 = h0;
float4 r0 = input.read(uint2(w0, h0), gid.z);
float4 r1 = input.read(uint2(w1, h0), gid.z);
float4 r2 = input.read(uint2(w0, h1), gid.z);
float4 r3 = input.read(uint2(w1, h1), gid.z);
r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3);
}
output.write(r, gid.xy, gid.z);
}
kernel void resize_bilinear_half(texture2d_array<half, access::read> input [[texture(0)]],
texture2d_array<half, access::write> output [[texture(2)]],
constant resize_bilinear_param & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
half4 r;
if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) {
r = input.read(gid.xy, gid.z);
} else {
half w = gid.x * pm.ratio_w;
half h = gid.y * pm.ratio_h;
uint w0 = w, h0 = h;
uint w1 = w0 + 1, h1 = h0 + 1;
half w1lambda = w - w0, h1lambda = h - h0;
half w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda;
if (w1 >= input.get_width()) w1 = w0;
if (h1 >= input.get_height()) h1 = h0;
half4 r0 = input.read(uint2(w0, h0), gid.z);
half4 r1 = input.read(uint2(w1, h0), gid.z);
half4 r2 = input.read(uint2(w0, h1), gid.z);
half4 r3 = input.read(uint2(w1, h1), gid.z);
r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3);
}
output.write(r, gid.xy, gid.z);
output.write(r, gid.xy, gid.z);
}
//
// Scale.metal
// paddle-mobile
//
// Created by liuRuiLong on 2019/1/4.
// Copyright © 2019 orange. All rights reserved.
//
#include <metal_stdlib>
using namespace metal;
kernel void scale(texture2d<float, access::sample> inTexture [[texture(0)]], texture2d<float, access::write> outTexture [[texture(1)]], uint2 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height()) return;
float w_stride = inTexture.get_width() / outTexture.get_width();
float h_stride = inTexture.get_height() / outTexture.get_height();
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 input = inTexture.sample(sample, float2(gid.x * w_stride, gid.y * h_stride), 0);
outTexture.write(input, gid);
}
kernel void scale_half(texture2d<float, access::sample> inTexture [[texture(0)]], texture2d<half, access::write> outTexture [[texture(1)]], uint2 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height()) return;
float w_stride = inTexture.get_width() / outTexture.get_width();
float h_stride = inTexture.get_height() / outTexture.get_height();
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 input = inTexture.sample(sample, float2(gid.x * w_stride, gid.y * h_stride), 0);
outTexture.write(half4(input), gid);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
kernel void shape() {
}
kernel void shape_half() {
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef P
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define FUNC(f, p) CONCAT2_(f, p)
#define VECTOR(p, n) CONCAT2(p, n)
kernel void FUNC(softmax, P)(texture2d_array<P, access::read> inTexture [[texture(0)]],
texture2d_array<P, access::write> outTexture [[texture(1)]],
constant SoftmaxParam &sp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
// int zsize = inTexture.get_array_size();
P maxv = inTexture.read(uint2(0, gid.y), 0)[0];
int group = sp.K / 4;
int remain = sp.K % 4;
for (int x = 0; x < group; x++) {
VECTOR(P, 4) r = inTexture.read(uint2(x, gid.y), 0);
maxv = max(maxv, max(r[0], max(r[1], max(r[2], r[3]))));
}
if (remain > 0) {
VECTOR(P, 4) r = inTexture.read(uint2(group, gid.y), 0);
for (int i = 0; i < remain; i++) {
maxv = max(maxv, r[i]);
}
}
VECTOR(P, 4) rsum = {0, 0, 0, 0};
for (int x = 0; x < group; x++) {
VECTOR(P, 4) r = inTexture.read(uint2(x, gid.y), 0);
rsum += exp(r - maxv);
}
P sum = rsum[0] + rsum[1] + rsum[2] + rsum[3];
if (remain > 0) {
VECTOR(P, 4) r = inTexture.read(uint2(group, gid.y), 0);
for (int i = 0; i < remain; i++) {
sum += exp(r[i] - maxv);
}
}
VECTOR(P, 4) rr = inTexture.read(gid.xy, gid.z);
rr = exp(rr - maxv) / sum;
outTexture.write(rr, gid.xy, gid.z);
}
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
using namespace metal;
struct SoftmaxParam {
int N;
int K;
};
#define P float
#include "Softmax.inc.metal"
#undef P
#define P half
#include "Softmax.inc.metal"
#undef P
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef P
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d
#define CONCAT5_(a, b, c, d, e) a ## _ ## b ## _ ## c ## _ ## d ## _ ## e
#define FUNC(f, r, n, v, p) CONCAT5_(f, r, n, v, p)
#define VECTOR(p, n) CONCAT2(p, n)
#define FUNC_R(f, r) CONCAT2_(f, r)
#if V == VX
#define VV x
#elif V == VY
#define VV y
#elif V == VZ
#define VV z
#else
#define VV normal
#endif
#if V == VY
kernel void FUNC(split, R, N, VV, P)(texture2d_array<P, access::read> input [[texture(0)]],
texture2d_array<P, access::write> out1 [[texture(1)]],
texture2d_array<P, access::write> out2 [[texture(2)]],
#if N >= 3
texture2d_array<P, access::write> out3 [[texture(3)]],
#endif // N >= 3
#if N >= 4
texture2d_array<P, access::write> out4 [[texture(4)]],
#endif // N >= 4
constant SplitParam &sp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
VECTOR(P, 4) r = input.read(gid.xy, gid.z);
int y = gid.y - sp.offset;
if (y < sp.vdim[0]) {
out1.write(r, gid.xy, gid.z);
return;
}
y -= sp.vdim[0];
if (y < sp.vdim[1]) {
out2.write(r, uint2(gid.x, y), gid.z);
return;
}
#if N >= 3
y -= sp.vdim[1];
if (y < sp.vdim[2]) {
out3.write(r, uint2(gid.x, y), gid.z);
return;
}
#endif // N >= 3
#if N >= 4
y -= sp.vdim[2];
if (y < sp.vdim[3]) {
out4.write(r, uint2(gid.x, y), gid.z);
return;
}
#endif // N >= 4
}
#endif // V == VY
#if V == VX
kernel void FUNC(split, R, N, VV, P)(texture2d_array<P, access::read> input [[texture(0)]],
texture2d_array<P, access::write> out1 [[texture(1)]],
texture2d_array<P, access::write> out2 [[texture(2)]],
#if N >= 3
texture2d_array<P, access::write> out3 [[texture(3)]],
#endif // N >= 3
#if N >= 4
texture2d_array<P, access::write> out4 [[texture(4)]],
#endif // N >= 4
constant SplitParam &sp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
VECTOR(P, 4) r = input.read(gid.xy, gid.z);
int x = gid.x;
if (x < sp.vdim[0]) {
out1.write(r, gid.xy, gid.z);
return;
}
x -= sp.vdim[0];
if (x < sp.vdim[1]) {
out2.write(r, uint2(x, gid.y), gid.z);
return;
}
#if N >= 3
x -= sp.vdim[1];
if (x < sp.vdim[2]) {
out3.write(r, uint2(x, gid.y), gid.z);
return;
}
#endif // N >= 3
#if N >= 4
x -= sp.vdim[2];
if (x < sp.vdim[3]) {
out4.write(r, uint2(x, gid.y), gid.z);
return;
}
#endif // N >= 4
}
#endif // V == VX
#undef VV
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
struct SplitParam {
int32_t idim[4];
int32_t axis;
int32_t offset;
int32_t trans[4];
int32_t vdim[4];
};
#define VNORMAL 1
#define VX 2
#define VY 3
#define VZ 4
// only support split_{2, 3, 4}_{2, 3, 4}_y_{float, half}
// only support split_{3, 4}_{2, 3, 4}_x_{float, half}
//// ssd-ar: (R=3, N=2, V=y)
#define V VY
#define R 3
#define N 2
#define P float
#include "Split.inc.metal"
#undef P
#define P half
#include "Split.inc.metal"
#undef P
#undef N
#undef R
#undef V
//// ssd-ar: (R=2, N=2, V=y)
#define V VY
#define R 2
#define N 2
#define P float
#include "Split.inc.metal"
#undef P
#define P half
#include "Split.inc.metal"
#undef P
#undef N
#undef R
#undef V
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef P
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c
#define FUNC(f, r, p) CONCAT3_(f, r, p)
#define VECTOR(p, n) CONCAT2(p, n)
kernel void FUNC(transpose, R, P)(texture2d_array<P, access::read> inTexture [[texture(0)]],
texture2d_array<P, access::write> outTexture [[texture(1)]],
constant TransposeParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
VECTOR(P, 4) r;
int oxyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0};
int iabcd[4], oabcd[4], ixyzn[4];
for (int n = 0; n < 4; n++) {
oxyzn[3] = n;
#if R == 4
xyzn2abcd_4(pm.oC, oxyzn, iabcd);
#endif // R == 4
#if R == 3
xyzn2abcd_3(oxyzn, oabcd);
#endif // R == 3
#if R == 2
xyzn2abcd_2(oxyzn, oabcd);
#endif // R == 2
iabcd[pm.axis[0]] = oabcd[0];
iabcd[pm.axis[1]] = oabcd[1];
iabcd[pm.axis[2]] = oabcd[2];
iabcd[pm.axis[3]] = oabcd[3];
#if R == 4
abcd2xyzn_4(pm.iC, iabcd, ixyzn);
#endif // R == 4
#if R == 3
abcd2xyzn_3(iabcd, ixyzn);
#endif // R == 3
#if R == 2
abcd2xyzn_2(iabcd, ixyzn);
#endif // R == 2
r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]];
}
outTexture.write(r, gid.xy, gid.z);
}
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
struct TransposeParam {
int iC;
int oC;
int axis[4];
};
kernel void transpose_copy_float(texture2d_array<float, access::read> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant TransposeParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
outTexture.write(inTexture.read(gid.xy, gid.z), gid.xy, gid.z);
}
kernel void transpose_copy_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant TransposeParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
outTexture.write(inTexture.read(gid.xy, gid.z), gid.xy, gid.z);
}
#define R 4
#define P float
#include "TransposeKernel.inc.metal"
#undef P
#define P half
#include "TransposeKernel.inc.metal"
#undef P
#undef R
#define R 3
#define P float
#include "TransposeKernel.inc.metal"
#undef P
#define P half
#include "TransposeKernel.inc.metal"
#undef P
#undef R
#define R 2
#define P float
#include "TransposeKernel.inc.metal"
#undef P
#define P half
#include "TransposeKernel.inc.metal"
#undef P
#undef R
...@@ -38,8 +38,11 @@ class MulticlassNMSParam<P: PrecisionType>: OpParam { ...@@ -38,8 +38,11 @@ class MulticlassNMSParam<P: PrecisionType>: OpParam {
class MulticlassNMSOp<P: PrecisionType>: Operator<MulticlassNMSKernel<P>, MulticlassNMSParam<P>>, Runable, Creator, InferShaperable{ class MulticlassNMSOp<P: PrecisionType>: Operator<MulticlassNMSKernel<P>, MulticlassNMSParam<P>>, Runable, Creator, InferShaperable{
func inputVariant() -> [String : [Variant]] { func inputVariant() -> [String : [MTLBuffer]] {
return ["Scores" : [para.middleOutput], "BBoxes" : [para.bboxOutput]] guard let scoreBuffer = para.middleOutput.resultBuffer, let bboxBuffer = para.middleOutput.resultBuffer else {
fatalError()
}
return ["Scores" : [scoreBuffer], "BBoxes" : [bboxBuffer]]
} }
func computeMiddleResult(device: MTLDevice, buffer: MTLCommandBuffer) { func computeMiddleResult(device: MTLDevice, buffer: MTLCommandBuffer) {
......
...@@ -47,6 +47,7 @@ class ReluOp<P: PrecisionType>: Operator<ReluKernel<P>, ReluParam<P>>, Runable, ...@@ -47,6 +47,7 @@ class ReluOp<P: PrecisionType>: Operator<ReluKernel<P>, ReluParam<P>>, Runable,
func delogOutput() { func delogOutput() {
print(" \(type) output: ") print(" \(type) output: ")
print(para.output.metalTexture)
print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray()) print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray())
// let device = para.output.metalTexture!.device // let device = para.output.metalTexture!.device
// let outputArray: [Float32] = device.texture2tensor(texture: para.output.metalTexture, dim: para.output.tensorDim.dims, transpose: para.output.transpose) // let outputArray: [Float32] = device.texture2tensor(texture: para.output.metalTexture, dim: para.output.tensorDim.dims, transpose: para.output.transpose)
......
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
import Foundation import Foundation
class BlockDesc { public class BlockDesc {
let index: Int let index: Int
let parentIndex: Int let parentIndex: Int
let vars: [VarDesc] public let vars: [VarDesc]
let ops: [OpDesc] let ops: [OpDesc]
init(block: PaddleMobile_Framework_Proto_BlockDesc) { init(block: PaddleMobile_Framework_Proto_BlockDesc) {
index = Int(block.idx) index = Int(block.idx)
...@@ -45,7 +45,7 @@ class BlockDesc { ...@@ -45,7 +45,7 @@ class BlockDesc {
} }
extension BlockDesc: CustomStringConvertible, CustomDebugStringConvertible { extension BlockDesc: CustomStringConvertible, CustomDebugStringConvertible {
var description: String { public var description: String {
var str = "" var str = ""
for i in 0..<ops.count { for i in 0..<ops.count {
...@@ -61,9 +61,7 @@ extension BlockDesc: CustomStringConvertible, CustomDebugStringConvertible { ...@@ -61,9 +61,7 @@ extension BlockDesc: CustomStringConvertible, CustomDebugStringConvertible {
return str return str
} }
var debugDescription: String { public var debugDescription: String {
return description return description
} }
} }
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
import Foundation import Foundation
public class Program { @objc public class Program: NSObject {
let paramPath: String public let paramPath: String
let programDesc: ProgramDesc public let programDesc: ProgramDesc
let scope: Scope public let scope: Scope
init(inProgramDesc: ProgramDesc, inParamPath: String, inScope: Scope) { init(inProgramDesc: ProgramDesc, inParamPath: String, inScope: Scope) {
programDesc = inProgramDesc programDesc = inProgramDesc
paramPath = inParamPath paramPath = inParamPath
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import Foundation import Foundation
public class ProgramDesc { public class ProgramDesc {
var blocks: [BlockDesc] = [] public var blocks: [BlockDesc] = []
init(protoProgram: PaddleMobile_Framework_Proto_ProgramDesc) { init(protoProgram: PaddleMobile_Framework_Proto_ProgramDesc) {
for block in protoProgram.blocks { for block in protoProgram.blocks {
self.blocks.append(BlockDesc.init(block: block)) self.blocks.append(BlockDesc.init(block: block))
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import Foundation import Foundation
class Scope { public class Scope {
let feedKey: String let feedKey: String
let fetchKey: String let fetchKey: String
func setInput(input: Variant) { func setInput(input: Variant) {
...@@ -29,7 +29,7 @@ class Scope { ...@@ -29,7 +29,7 @@ class Scope {
return vars[feedKey]; return vars[feedKey];
} }
func output() -> Variant? { public func output() -> Variant? {
return vars[fetchKey]; return vars[fetchKey];
} }
...@@ -38,7 +38,7 @@ class Scope { ...@@ -38,7 +38,7 @@ class Scope {
fetchKey = inFetchKey fetchKey = inFetchKey
} }
var vars: [String : Variant] = [:] public var vars: [String : Variant] = [:]
subscript(key: String) -> Variant?{ subscript(key: String) -> Variant?{
get { get {
return vars[key] return vars[key]
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import Foundation import Foundation
enum VarTypeType: Int { public enum VarTypeType: Int {
case ErrorType = -1, case ErrorType = -1,
Bool = 0, Bool = 0,
Int16 = 1, Int16 = 1,
...@@ -56,10 +56,10 @@ enum VarTypeType: Int { ...@@ -56,10 +56,10 @@ enum VarTypeType: Int {
} }
} }
class VarDesc { public class VarDesc {
let name: String public let name: String
let persistable: Bool public let persistable: Bool
let type: VarTypeType public let type: VarTypeType
let tensorDesc: TensorDesc? let tensorDesc: TensorDesc?
init(protoVarDesc: PaddleMobile_Framework_Proto_VarDesc) { init(protoVarDesc: PaddleMobile_Framework_Proto_VarDesc) {
type = VarTypeType.init(rawValue: protoVarDesc.type.type.rawValue) ?? .ErrorType type = VarTypeType.init(rawValue: protoVarDesc.type.type.rawValue) ?? .ErrorType
...@@ -79,7 +79,7 @@ class VarDesc { ...@@ -79,7 +79,7 @@ class VarDesc {
} }
extension VarDesc: CustomStringConvertible, CustomDebugStringConvertible { extension VarDesc: CustomStringConvertible, CustomDebugStringConvertible {
var description: String { public var description: String {
var str = "" var str = ""
str += "var name \(name): \n" str += "var name \(name): \n"
if let inTensorDesc = tensorDesc { if let inTensorDesc = tensorDesc {
...@@ -93,7 +93,7 @@ extension VarDesc: CustomStringConvertible, CustomDebugStringConvertible { ...@@ -93,7 +93,7 @@ extension VarDesc: CustomStringConvertible, CustomDebugStringConvertible {
return str return str
} }
var debugDescription: String { public var debugDescription: String {
return description return description
} }
} }
...@@ -14,47 +14,42 @@ ...@@ -14,47 +14,42 @@
import Foundation import Foundation
public struct Dim { @objc public class Dim: NSObject {
public init(inDim: [Int]) { private(set) var dims: [Int]
dims = inDim
} @objc public init(inDim: [Int]) {
dims = inDim
public init(inDim: (n: Int, h: Int, w: Int, c: Int)) { }
dims = [inDim.n, inDim.h, inDim.w, inDim.c]
} public func cout() -> Int {
return dims.count
mutating func swapeDimAt(index1: Int, index2: Int) { }
dims.swapAt(index1, index2)
} public func numel() -> Int {
return dims.reduce(1) { $0 * $1 }
func cout() -> Int { }
return dims.count
} public static func ==(left: Dim, right: Dim) -> Bool {
return left.dims == right.dims;
func numel() -> Int { }
return dims.reduce(1) { $0 * $1 }
}
public static func ==(left: Dim, right: Dim) -> Bool {
return left.dims == right.dims;
}
public static func !=(left: Dim, right: Dim) -> Bool { public static func !=(left: Dim, right: Dim) -> Bool {
return left.dims != right.dims; return left.dims != right.dims;
} }
public subscript(index: Int) -> Int { public subscript(index: Int) -> Int {
return dims[index]; return dims[index];
} }
private(set) var dims: [Int] public override var description: String {
private init(){ return "\(dims)"
fatalError() }
}
} func swapeDimAt(index1: Int, index2: Int) {
dims.swapAt(index1, index2)
extension Dim: CustomStringConvertible { }
public var description: String {
return "\(dims)" private override init(){
} fatalError()
}
} }
...@@ -19,13 +19,12 @@ let testTo = 5 ...@@ -19,13 +19,12 @@ let testTo = 5
var isTest = false var isTest = false
let computePrecision: ComputePrecision = .Float32 @objc public class GPUResultHolder: NSObject{
public class GPUResultHolder { @objc public let dim: [Int]
public let dim: [Int] @objc public let capacity: Int
public let capacity: Int @objc public var resultPointer: UnsafeMutablePointer<Float32>?
public var resultPointer: UnsafeMutablePointer<Float32>? @objc public var intermediateResults: [String : [MTLBuffer]]?
public var intermediateResults: [String : [Variant]]? public init(inDim: [Int], inPointer: UnsafeMutablePointer<Float32>?, inCapacity: Int, inIntermediateResults: [String : [MTLBuffer]]? = nil) {
public init(inDim: [Int], inPointer: UnsafeMutablePointer<Float32>?, inCapacity: Int, inIntermediateResults: [String : [Variant]]? = nil) {
dim = inDim dim = inDim
capacity = inCapacity capacity = inCapacity
...@@ -37,29 +36,10 @@ public class GPUResultHolder { ...@@ -37,29 +36,10 @@ public class GPUResultHolder {
intermediateResults = inIntermediateResults intermediateResults = inIntermediateResults
} }
} public override var description: String {
extension GPUResultHolder: CustomDebugStringConvertible, CustomStringConvertible {
public var debugDescription: String {
// var str = ""
// str += "Dim: \(dim) \n value:[ "
// if resultArr.count < 20 {
// for d in resultArr {
// str += " \(d) "
// }
// } else {
// for d in stride(from: 0, to: resultArr.count, by: resultArr.count/20) {
// str += " \(resultArr[d]) "
// }
// }
// str += " ]"
// return str
fatalError() fatalError()
} }
public var description: String {
return debugDescription
}
} }
public class Executor<P: PrecisionType> { public class Executor<P: PrecisionType> {
...@@ -69,32 +49,18 @@ public class Executor<P: PrecisionType> { ...@@ -69,32 +49,18 @@ public class Executor<P: PrecisionType> {
let device: MTLDevice let device: MTLDevice
let inflightSemaphore: DispatchSemaphore let inflightSemaphore: DispatchSemaphore
let queue: MTLCommandQueue let queue: MTLCommandQueue
public init(inDevice:MTLDevice, inQueue: MTLCommandQueue, inProgram: Program) throws { init(inDevice:MTLDevice, inQueue: MTLCommandQueue, inProgram: Program, initContext: InitContext) throws {
self.inflightSemaphore = DispatchSemaphore(value: 1) self.inflightSemaphore = DispatchSemaphore(value: 1)
program = inProgram program = inProgram
device = inDevice device = inDevice
queue = inQueue queue = inQueue
// print("before for ")
//print(program.scope.vars["fea_pyramid1_mbox_conf_flat.Flatten.output.1.tmp_0"])
for block in inProgram.programDesc.blocks { for block in inProgram.programDesc.blocks {
//block.ops.count //block.ops.count
for i in 0..<block.ops.count { for i in 0..<block.ops.count {
let opDesc = block.ops[i] let opDesc = block.ops[i]
do { do {
let op = try OpCreator<P>.shared.creat(device: inDevice, opDesc: opDesc, scope: inProgram.scope, initContext: initContext)
// print("in for i \(i): ")
// print(program.scope.vars["fea_pyramid1_mbox_conf_flat.Flatten.output.1.tmp_0"])
//
// if i == 56 {
// print(program.scope.vars["fea_pyramid1_mbox_conf_flat.Flatten.output.1.tmp_0"])
//
// }
let op = try OpCreator<P>.shared.creat(device: inDevice, opDesc: opDesc, scope: inProgram.scope)
ops.append(op) ops.append(op)
} catch let error { } catch let error {
throw error throw error
...@@ -134,7 +100,7 @@ public class Executor<P: PrecisionType> { ...@@ -134,7 +100,7 @@ public class Executor<P: PrecisionType> {
} }
} }
var outputTextures: [String : [Variant]]? var outputTextures: [String : [MTLBuffer]]?
if except > 0 { if except > 0 {
ops[ops.count - except].computeMiddleResult(device: device, buffer: buffer) ops[ops.count - except].computeMiddleResult(device: device, buffer: buffer)
outputTextures = ops[ops.count - except].inputVariant() outputTextures = ops[ops.count - except].inputVariant()
...@@ -159,7 +125,7 @@ public class Executor<P: PrecisionType> { ...@@ -159,7 +125,7 @@ public class Executor<P: PrecisionType> {
op.delogOutput() op.delogOutput()
} }
*/ */
var resultHolder: GPUResultHolder var resultHolder: GPUResultHolder
if except > 0 { if except > 0 {
resultHolder = GPUResultHolder.init(inDim: [], inPointer: nil, inCapacity: 0, inIntermediateResults: outputTextures) resultHolder = GPUResultHolder.init(inDim: [], inPointer: nil, inCapacity: 0, inIntermediateResults: outputTextures)
......
...@@ -28,9 +28,7 @@ extension Tensorial { ...@@ -28,9 +28,7 @@ extension Tensorial {
} }
} }
public enum ComputePrecision {
case Float32, Float16
}
class Tensor<P: PrecisionType>: Tensorial { class Tensor<P: PrecisionType>: Tensorial {
......
...@@ -71,7 +71,7 @@ extension InputTexture { ...@@ -71,7 +71,7 @@ extension InputTexture {
public class Texture: Tensorial { public class Texture: Tensorial {
var dim: Dim public var dim: Dim
public var tensorDim: Dim public var tensorDim: Dim
public var padToFourDim: Dim public var padToFourDim: Dim
private var textureDesc: MTLTextureDescriptor! private var textureDesc: MTLTextureDescriptor!
...@@ -96,7 +96,7 @@ public class Texture: Tensorial { ...@@ -96,7 +96,7 @@ public class Texture: Tensorial {
return metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3])) return metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3]))
} }
func initTexture(device: MTLDevice, inTranspose: [Int] = [0, 1, 2, 3], computePrecision: ComputePrecision = .Float16) { public func initTexture(device: MTLDevice, inTranspose: [Int] = [0, 1, 2, 3], computePrecision: ComputePrecision = .Float16) {
transpose = inTranspose transpose = inTranspose
for i in 0..<(4 - tensorDim.cout()) { for i in 0..<(4 - tensorDim.cout()) {
if i != inTranspose[i] { if i != inTranspose[i] {
...@@ -143,7 +143,7 @@ public class Texture: Tensorial { ...@@ -143,7 +143,7 @@ public class Texture: Tensorial {
metalTexture = device.makeTexture(descriptor: tmpTextureDes) ?! " texture nil " metalTexture = device.makeTexture(descriptor: tmpTextureDes) ?! " texture nil "
} }
func updateDims(inTensorDim: Dim, inDim: Dim) { public func updateDims(inTensorDim: Dim, inDim: Dim) {
var fourDim: Dim var fourDim: Dim
if inDim.cout() == 4 { if inDim.cout() == 4 {
fourDim = inDim fourDim = inDim
......
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
#pragma once #pragma once
#import "CPUCompute.h"
#import "PaddleMobileGPU.h"
#import <UIKit/UIKit.h> #import <UIKit/UIKit.h>
//! Project version number for paddle_mobile. //! Project version number for paddle_mobile.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册