提交 7cc0f8ef 编写于 作者: L liuruilong

run mobilenet+ssd genet

上级 b9965649
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
<rect key="frame" x="0.0" y="0.0" width="375" height="667"/> <rect key="frame" x="0.0" y="0.0" width="375" height="667"/>
<autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/> <autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/>
<subviews> <subviews>
<imageView userInteractionEnabled="NO" contentMode="scaleAspectFit" horizontalHuggingPriority="251" verticalHuggingPriority="251" ambiguous="YES" image="hand.jpg" translatesAutoresizingMaskIntoConstraints="NO" id="ZZh-fw-LwK"> <imageView userInteractionEnabled="NO" contentMode="scaleAspectFit" horizontalHuggingPriority="251" verticalHuggingPriority="251" translatesAutoresizingMaskIntoConstraints="NO" id="ZZh-fw-LwK">
<rect key="frame" x="0.0" y="20" width="375" height="247"/> <rect key="frame" x="0.0" y="20" width="375" height="247"/>
</imageView> </imageView>
<label opaque="NO" userInteractionEnabled="NO" contentMode="left" horizontalHuggingPriority="251" verticalHuggingPriority="251" ambiguous="YES" text="Thread:" textAlignment="natural" lineBreakMode="tailTruncation" baselineAdjustment="alignBaselines" adjustsFontSizeToFit="NO" translatesAutoresizingMaskIntoConstraints="NO" id="2EB-m2-a3L"> <label opaque="NO" userInteractionEnabled="NO" contentMode="left" horizontalHuggingPriority="251" verticalHuggingPriority="251" text="Thread:" textAlignment="natural" lineBreakMode="tailTruncation" baselineAdjustment="alignBaselines" adjustsFontSizeToFit="NO" translatesAutoresizingMaskIntoConstraints="NO" id="2EB-m2-a3L">
<rect key="frame" x="10" y="538" width="68" height="24"/> <rect key="frame" x="10" y="538" width="68" height="24"/>
<constraints> <constraints>
<constraint firstAttribute="width" constant="68" id="Q5J-tq-JSX"/> <constraint firstAttribute="width" constant="68" id="Q5J-tq-JSX"/>
...@@ -32,19 +32,19 @@ ...@@ -32,19 +32,19 @@
<nil key="textColor"/> <nil key="textColor"/>
<nil key="highlightedColor"/> <nil key="highlightedColor"/>
</label> </label>
<pickerView contentMode="scaleToFill" ambiguous="YES" translatesAutoresizingMaskIntoConstraints="NO" id="DlO-dk-RMr"> <pickerView contentMode="scaleToFill" translatesAutoresizingMaskIntoConstraints="NO" id="DlO-dk-RMr">
<rect key="frame" x="88" y="510.5" width="287" height="80"/> <rect key="frame" x="88" y="510.5" width="287" height="80"/>
<constraints> <constraints>
<constraint firstAttribute="height" constant="80" id="Sbi-05-Mwd"/> <constraint firstAttribute="height" constant="80" id="Sbi-05-Mwd"/>
</constraints> </constraints>
</pickerView> </pickerView>
<pickerView contentMode="scaleToFill" ambiguous="YES" translatesAutoresizingMaskIntoConstraints="NO" id="6MG-gv-hD5"> <pickerView contentMode="scaleToFill" translatesAutoresizingMaskIntoConstraints="NO" id="6MG-gv-hD5">
<rect key="frame" x="85" y="401" width="290" height="80"/> <rect key="frame" x="85" y="401" width="290" height="80"/>
<constraints> <constraints>
<constraint firstAttribute="height" constant="80" id="yAL-JY-G6b"/> <constraint firstAttribute="height" constant="80" id="yAL-JY-G6b"/>
</constraints> </constraints>
</pickerView> </pickerView>
<label opaque="NO" userInteractionEnabled="NO" contentMode="left" horizontalHuggingPriority="251" verticalHuggingPriority="251" ambiguous="YES" text="Models" textAlignment="natural" lineBreakMode="tailTruncation" baselineAdjustment="alignBaselines" adjustsFontSizeToFit="NO" translatesAutoresizingMaskIntoConstraints="NO" id="avL-VK-Kha"> <label opaque="NO" userInteractionEnabled="NO" contentMode="left" horizontalHuggingPriority="251" verticalHuggingPriority="251" text="Models" textAlignment="natural" lineBreakMode="tailTruncation" baselineAdjustment="alignBaselines" adjustsFontSizeToFit="NO" translatesAutoresizingMaskIntoConstraints="NO" id="avL-VK-Kha">
<rect key="frame" x="10" y="429" width="65" height="24"/> <rect key="frame" x="10" y="429" width="65" height="24"/>
<constraints> <constraints>
<constraint firstAttribute="width" constant="65" id="6oA-g2-Xq4"/> <constraint firstAttribute="width" constant="65" id="6oA-g2-Xq4"/>
...@@ -54,7 +54,7 @@ ...@@ -54,7 +54,7 @@
<nil key="textColor"/> <nil key="textColor"/>
<nil key="highlightedColor"/> <nil key="highlightedColor"/>
</label> </label>
<button opaque="NO" contentMode="scaleToFill" ambiguous="YES" contentHorizontalAlignment="center" contentVerticalAlignment="center" buttonType="roundedRect" showsTouchWhenHighlighted="YES" lineBreakMode="middleTruncation" translatesAutoresizingMaskIntoConstraints="NO" id="wUL-9N-u1V"> <button opaque="NO" contentMode="scaleToFill" contentHorizontalAlignment="center" contentVerticalAlignment="center" buttonType="roundedRect" showsTouchWhenHighlighted="YES" lineBreakMode="middleTruncation" translatesAutoresizingMaskIntoConstraints="NO" id="wUL-9N-u1V">
<rect key="frame" x="16" y="597" width="63.5" height="30"/> <rect key="frame" x="16" y="597" width="63.5" height="30"/>
<color key="backgroundColor" white="0.0" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/> <color key="backgroundColor" white="0.0" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/>
<state key="normal" title="Image"> <state key="normal" title="Image">
...@@ -64,7 +64,7 @@ ...@@ -64,7 +64,7 @@
<action selector="selectImageAct:" destination="BYZ-38-t0r" eventType="touchUpInside" id="5uR-SM-fKO"/> <action selector="selectImageAct:" destination="BYZ-38-t0r" eventType="touchUpInside" id="5uR-SM-fKO"/>
</connections> </connections>
</button> </button>
<button opaque="NO" contentMode="scaleToFill" ambiguous="YES" contentHorizontalAlignment="center" contentVerticalAlignment="center" buttonType="roundedRect" showsTouchWhenHighlighted="YES" lineBreakMode="middleTruncation" translatesAutoresizingMaskIntoConstraints="NO" id="XpL-9M-UOp"> <button opaque="NO" contentMode="scaleToFill" contentHorizontalAlignment="center" contentVerticalAlignment="center" buttonType="roundedRect" showsTouchWhenHighlighted="YES" lineBreakMode="middleTruncation" translatesAutoresizingMaskIntoConstraints="NO" id="XpL-9M-UOp">
<rect key="frame" x="109.5" y="597" width="63" height="30"/> <rect key="frame" x="109.5" y="597" width="63" height="30"/>
<color key="backgroundColor" white="0.0" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/> <color key="backgroundColor" white="0.0" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/>
<state key="normal" title="Load"> <state key="normal" title="Load">
...@@ -74,7 +74,7 @@ ...@@ -74,7 +74,7 @@
<action selector="loadAct:" destination="BYZ-38-t0r" eventType="touchUpInside" id="fZ5-CQ-jCY"/> <action selector="loadAct:" destination="BYZ-38-t0r" eventType="touchUpInside" id="fZ5-CQ-jCY"/>
</connections> </connections>
</button> </button>
<button opaque="NO" contentMode="scaleToFill" ambiguous="YES" contentHorizontalAlignment="center" contentVerticalAlignment="center" buttonType="roundedRect" showsTouchWhenHighlighted="YES" lineBreakMode="middleTruncation" translatesAutoresizingMaskIntoConstraints="NO" id="R90-Yf-S6g"> <button opaque="NO" contentMode="scaleToFill" contentHorizontalAlignment="center" contentVerticalAlignment="center" buttonType="roundedRect" showsTouchWhenHighlighted="YES" lineBreakMode="middleTruncation" translatesAutoresizingMaskIntoConstraints="NO" id="R90-Yf-S6g">
<rect key="frame" x="202.5" y="597" width="63.5" height="30"/> <rect key="frame" x="202.5" y="597" width="63.5" height="30"/>
<color key="backgroundColor" white="0.0" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/> <color key="backgroundColor" white="0.0" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/>
<state key="normal" title="Predict"> <state key="normal" title="Predict">
...@@ -84,7 +84,7 @@ ...@@ -84,7 +84,7 @@
<action selector="predictAct:" destination="BYZ-38-t0r" eventType="touchUpInside" id="Iyy-sY-gt4"/> <action selector="predictAct:" destination="BYZ-38-t0r" eventType="touchUpInside" id="Iyy-sY-gt4"/>
</connections> </connections>
</button> </button>
<button opaque="NO" contentMode="scaleToFill" ambiguous="YES" contentHorizontalAlignment="center" contentVerticalAlignment="center" buttonType="roundedRect" showsTouchWhenHighlighted="YES" lineBreakMode="middleTruncation" translatesAutoresizingMaskIntoConstraints="NO" id="a3K-ri-NVs"> <button opaque="NO" contentMode="scaleToFill" contentHorizontalAlignment="center" contentVerticalAlignment="center" buttonType="roundedRect" showsTouchWhenHighlighted="YES" lineBreakMode="middleTruncation" translatesAutoresizingMaskIntoConstraints="NO" id="a3K-ri-NVs">
<rect key="frame" x="296" y="597" width="63" height="30"/> <rect key="frame" x="296" y="597" width="63" height="30"/>
<color key="backgroundColor" white="0.0" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/> <color key="backgroundColor" white="0.0" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/>
<state key="normal" title="Clear"> <state key="normal" title="Clear">
...@@ -94,7 +94,7 @@ ...@@ -94,7 +94,7 @@
<action selector="clearAct:" destination="BYZ-38-t0r" eventType="touchUpInside" id="JYf-UX-rCR"/> <action selector="clearAct:" destination="BYZ-38-t0r" eventType="touchUpInside" id="JYf-UX-rCR"/>
</connections> </connections>
</button> </button>
<view contentMode="scaleToFill" ambiguous="YES" translatesAutoresizingMaskIntoConstraints="NO" id="w7H-Sk-Rai"> <view contentMode="scaleToFill" translatesAutoresizingMaskIntoConstraints="NO" id="w7H-Sk-Rai">
<rect key="frame" x="79.5" y="597" width="30" height="30"/> <rect key="frame" x="79.5" y="597" width="30" height="30"/>
<color key="backgroundColor" white="1" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/> <color key="backgroundColor" white="1" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/>
<constraints> <constraints>
...@@ -102,7 +102,7 @@ ...@@ -102,7 +102,7 @@
<constraint firstAttribute="width" constant="30" id="vYd-Fc-KAj"/> <constraint firstAttribute="width" constant="30" id="vYd-Fc-KAj"/>
</constraints> </constraints>
</view> </view>
<view contentMode="scaleToFill" ambiguous="YES" translatesAutoresizingMaskIntoConstraints="NO" id="T4O-nx-ciH"> <view contentMode="scaleToFill" translatesAutoresizingMaskIntoConstraints="NO" id="T4O-nx-ciH">
<rect key="frame" x="266" y="597" width="30" height="30"/> <rect key="frame" x="266" y="597" width="30" height="30"/>
<color key="backgroundColor" white="1" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/> <color key="backgroundColor" white="1" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/>
<constraints> <constraints>
...@@ -110,7 +110,7 @@ ...@@ -110,7 +110,7 @@
<constraint firstAttribute="width" constant="30" id="fXE-S7-ZXL"/> <constraint firstAttribute="width" constant="30" id="fXE-S7-ZXL"/>
</constraints> </constraints>
</view> </view>
<view contentMode="scaleToFill" ambiguous="YES" translatesAutoresizingMaskIntoConstraints="NO" id="976-fk-Kx2"> <view contentMode="scaleToFill" translatesAutoresizingMaskIntoConstraints="NO" id="976-fk-Kx2">
<rect key="frame" x="172.5" y="597" width="30" height="30"/> <rect key="frame" x="172.5" y="597" width="30" height="30"/>
<color key="backgroundColor" white="1" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/> <color key="backgroundColor" white="1" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/>
<constraints> <constraints>
...@@ -118,7 +118,7 @@ ...@@ -118,7 +118,7 @@
<constraint firstAttribute="width" constant="30" id="L4p-hP-s5C"/> <constraint firstAttribute="width" constant="30" id="L4p-hP-s5C"/>
</constraints> </constraints>
</view> </view>
<label opaque="NO" userInteractionEnabled="NO" contentMode="left" horizontalHuggingPriority="251" verticalHuggingPriority="251" ambiguous="YES" text="耗时:" lineBreakMode="tailTruncation" numberOfLines="0" baselineAdjustment="alignBaselines" adjustsFontSizeToFit="NO" translatesAutoresizingMaskIntoConstraints="NO" id="m5L-O7-P31"> <label opaque="NO" userInteractionEnabled="NO" contentMode="left" horizontalHuggingPriority="251" verticalHuggingPriority="251" text="耗时:" lineBreakMode="tailTruncation" numberOfLines="0" baselineAdjustment="alignBaselines" adjustsFontSizeToFit="NO" translatesAutoresizingMaskIntoConstraints="NO" id="m5L-O7-P31">
<rect key="frame" x="15" y="277" width="350" height="38"/> <rect key="frame" x="15" y="277" width="350" height="38"/>
<constraints> <constraints>
<constraint firstAttribute="height" constant="38" id="6SS-sb-7I2"/> <constraint firstAttribute="height" constant="38" id="6SS-sb-7I2"/>
...@@ -133,7 +133,7 @@ ...@@ -133,7 +133,7 @@
<constraint firstAttribute="width" secondItem="4ey-Xr-U4e" secondAttribute="height" multiplier="6.5:1" id="8c5-FF-lB9"/> <constraint firstAttribute="width" secondItem="4ey-Xr-U4e" secondAttribute="height" multiplier="6.5:1" id="8c5-FF-lB9"/>
</constraints> </constraints>
</imageView> </imageView>
<textView clipsSubviews="YES" multipleTouchEnabled="YES" contentMode="scaleToFill" ambiguous="YES" editable="NO" text="结果:" textAlignment="natural" translatesAutoresizingMaskIntoConstraints="NO" id="VQn-bS-fWp"> <textView clipsSubviews="YES" multipleTouchEnabled="YES" contentMode="scaleToFill" editable="NO" text="结果:" textAlignment="natural" translatesAutoresizingMaskIntoConstraints="NO" id="VQn-bS-fWp">
<rect key="frame" x="10" y="323" width="355" height="70"/> <rect key="frame" x="10" y="323" width="355" height="70"/>
<color key="backgroundColor" white="1" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/> <color key="backgroundColor" white="1" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/>
<constraints> <constraints>
...@@ -203,7 +203,6 @@ ...@@ -203,7 +203,6 @@
</scene> </scene>
</scenes> </scenes>
<resources> <resources>
<image name="hand.jpg" width="564" height="664"/>
<image name="paddle-mobile.png" width="402" height="62"/> <image name="paddle-mobile.png" width="402" height="62"/>
</resources> </resources>
</document> </document>
...@@ -33,7 +33,7 @@ class MobileNet_ssd_hand: Net{ ...@@ -33,7 +33,7 @@ class MobileNet_ssd_hand: Net{
return " \(res)" return " \(res)"
} }
func fetchResult(paddleMobileRes: ResultHolder<Float32>) -> [Float32] { func fetchResult(paddleMobileRes: ResultHolder) -> [Float32] {
guard let interRes = paddleMobileRes.intermediateResults else { guard let interRes = paddleMobileRes.intermediateResults else {
fatalError(" need have inter result ") fatalError(" need have inter result ")
...@@ -47,13 +47,17 @@ class MobileNet_ssd_hand: Net{ ...@@ -47,13 +47,17 @@ class MobileNet_ssd_hand: Net{
fatalError() fatalError()
} }
var scoreFormatArr: [Float32] = score.metalTexture.realNHWC(dim: (n: score.originDim[0], h: score.originDim[1], w: score.originDim[2], c: score.originDim[3])) var scoreFormatArr: [Float32] = score.metalTexture.realNHWC(dim: (n: score.padToFourDim[0], h: score.padToFourDim[1], w: score.padToFourDim[2], c: score.padToFourDim[3]))
print("score: ")
print(scoreFormatArr.strideArray())
var bboxArr = bbox.metalTexture.float32Array() var bboxArr = bbox.metalTexture.float32Array()
print("bbox: ")
print(bboxArr.strideArray())
let nmsCompute = NMSCompute.init() let nmsCompute = NMSCompute.init()
nmsCompute.scoreThredshold = 0.01 nmsCompute.scoreThredshold = 0.01
nmsCompute.nmsTopK = 200 nmsCompute.nmsTopK = 400
nmsCompute.keepTopK = 200 nmsCompute.keepTopK = 200
nmsCompute.nmsEta = 1.0 nmsCompute.nmsEta = 1.0
nmsCompute.nmsThreshold = 0.45 nmsCompute.nmsThreshold = 0.45
......
...@@ -37,7 +37,7 @@ protocol Net { ...@@ -37,7 +37,7 @@ protocol Net {
var preprocessKernel: CusomKernel { get } var preprocessKernel: CusomKernel { get }
func getTexture(image: CGImage, getTexture: @escaping (MTLTexture) -> Void) func getTexture(image: CGImage, getTexture: @escaping (MTLTexture) -> Void)
func resultStr(res: [Float]) -> String func resultStr(res: [Float]) -> String
func fetchResult(paddleMobileRes: ResultHolder<Float32>) -> [Float32] func fetchResult(paddleMobileRes: ResultHolder) -> [Float32]
mutating func load() throws mutating func load() throws
func predict(inTexture: MTLTexture, completion: @escaping ((time:TimeInterval, resultArray: [Float32])) -> Void) throws func predict(inTexture: MTLTexture, completion: @escaping ((time:TimeInterval, resultArray: [Float32])) -> Void) throws
...@@ -82,7 +82,7 @@ extension Net { ...@@ -82,7 +82,7 @@ extension Net {
} }
} }
func fetchResult(paddleMobileRes: ResultHolder<Float32>) -> [Float32] { func fetchResult(paddleMobileRes: ResultHolder) -> [Float32] {
return paddleMobileRes.resultArr return paddleMobileRes.resultArr
} }
......
...@@ -87,7 +87,7 @@ class ViewController: UIViewController { ...@@ -87,7 +87,7 @@ class ViewController: UIViewController {
fatalError() fatalError()
} }
// print(result.resultArray) print(result.resultArray.strideArray())
if i == max - 1 { if i == max - 1 {
let time = Date.init().timeIntervalSince(startDate) let time = Date.init().timeIntervalSince(startDate)
DispatchQueue.main.async { DispatchQueue.main.async {
...@@ -109,7 +109,7 @@ class ViewController: UIViewController { ...@@ -109,7 +109,7 @@ class ViewController: UIViewController {
threadPickerView.delegate = self threadPickerView.delegate = self
threadPickerView.dataSource = self threadPickerView.dataSource = self
selectImage = UIImage.init(named: "banana.jpeg") selectImage = UIImage.init(named: "hand.jpg")
selectImageView.image = selectImage selectImageView.image = selectImage
net.getTexture(image: selectImage!.cgImage!) {[weak self] (texture) in net.getTexture(image: selectImage!.cgImage!) {[weak self] (texture) in
self?.toPredictTexture = texture self?.toPredictTexture = texture
......
...@@ -113,7 +113,7 @@ extension MTLDevice { ...@@ -113,7 +113,7 @@ extension MTLDevice {
return tensor return tensor
} }
func tensor2texture<P>(value: [P], dim: [Int], transpose: [Int] = [0, 1, 2, 3]) -> MTLTexture { func tensor2texture<P>(value: [P], dim: [Int], transpose: [Int] = [0, 1, 2, 3], inComputePrecision: ComputePrecision = .Float32) -> MTLTexture {
if value.count > 0 { if value.count > 0 {
assert(value.count == dim.reduce(1) { $0 * $1 }) assert(value.count == dim.reduce(1) { $0 * $1 })
} }
...@@ -129,7 +129,13 @@ extension MTLDevice { ...@@ -129,7 +129,13 @@ extension MTLDevice {
textureDesc.height = ndim[1] textureDesc.height = ndim[1]
textureDesc.depth = 1 textureDesc.depth = 1
textureDesc.usage = [.shaderRead, .shaderWrite] textureDesc.usage = [.shaderRead, .shaderWrite]
textureDesc.pixelFormat = .rgba32Float
if inComputePrecision == .Float16 {
textureDesc.pixelFormat = .rgba16Float
} else if inComputePrecision == .Float32 {
textureDesc.pixelFormat = .rgba32Float
}
textureDesc.textureType = .type2DArray textureDesc.textureType = .type2DArray
textureDesc.storageMode = .shared textureDesc.storageMode = .shared
textureDesc.cpuCacheMode = .defaultCache textureDesc.cpuCacheMode = .defaultCache
...@@ -355,12 +361,7 @@ public extension MTLTexture { ...@@ -355,12 +361,7 @@ public extension MTLTexture {
// n c h w - dim // n c h w - dim
func toTensor(dim: (n: Int, c: Int, h: Int, w: Int)) -> [Float32] { func toTensor(dim: (n: Int, c: Int, h: Int, w: Int)) -> [Float32] {
// print("origin dim: \(dim)")
print("texture: ")
print(self)
var textureArray: [Float32] var textureArray: [Float32]
// if texturePrecision == .Float16
if pixelFormat == .rgba32Float { if pixelFormat == .rgba32Float {
textureArray = floatArray { (i : Float32) -> Float32 in textureArray = floatArray { (i : Float32) -> Float32 in
return i return i
...@@ -388,7 +389,6 @@ public extension MTLTexture { ...@@ -388,7 +389,6 @@ public extension MTLTexture {
} }
} }
} }
print(" tensor count -- \(output.count)")
return output return output
} }
......
...@@ -18,14 +18,14 @@ let testTo = 161 ...@@ -18,14 +18,14 @@ let testTo = 161
var isTest = false var isTest = false
let computePrecision: ComputePrecision = .Float32 let computePrecision: ComputePrecision = .Float16
public class ResultHolder<P: PrecisionType> { public class ResultHolder {
public let dim: [Int] public let dim: [Int]
public let resultArr: [P] public let resultArr: [Float32]
public var intermediateResults: [String : [Variant]]? public var intermediateResults: [String : [Variant]]?
public let elapsedTime: Double public let elapsedTime: Double
public init(inDim: [Int], inResult: [P], inElapsedTime: Double, inIntermediateResults: [String : [Variant]]? = nil) { public init(inDim: [Int], inResult: [Float32], inElapsedTime: Double, inIntermediateResults: [String : [Variant]]? = nil) {
dim = inDim dim = inDim
resultArr = inResult resultArr = inResult
elapsedTime = inElapsedTime elapsedTime = inElapsedTime
...@@ -78,7 +78,7 @@ public class Executor<P: PrecisionType> { ...@@ -78,7 +78,7 @@ public class Executor<P: PrecisionType> {
} }
} }
public func predict(input: MTLTexture, dim: [Int], completionHandle: @escaping (ResultHolder<P>) -> Void, preProcessKernle: CusomKernel? = nil, except: Int = 0) throws { public func predict(input: MTLTexture, dim: [Int], completionHandle: @escaping (ResultHolder) -> Void, preProcessKernle: CusomKernel? = nil, except: Int = 0) throws {
guard let buffer = queue.makeCommandBuffer() else { guard let buffer = queue.makeCommandBuffer() else {
throw PaddleMobileError.predictError(message: "CommandBuffer is nil") throw PaddleMobileError.predictError(message: "CommandBuffer is nil")
} }
...@@ -114,12 +114,10 @@ public class Executor<P: PrecisionType> { ...@@ -114,12 +114,10 @@ public class Executor<P: PrecisionType> {
buffer.addCompletedHandler { (commandbuffer) in buffer.addCompletedHandler { (commandbuffer) in
// let inputArr = resInput.floatArray(res: { (p:P) -> P in // let inputArr = resInput.toTensor(dim: (n: dim[0], c: dim[3], h: dim[1], w: dim[2]))
// return p //// print(inputArr.strideArray())
// })
// print(inputArr.strideArray())
// //
// writeToLibrary(fileName: "banana", array: inputArr) // writeToLibrary(fileName: "test_image_ssd", array: inputArr)
// print("write to library done") // print("write to library done")
// return // return
// print(inputArr) // print(inputArr)
...@@ -142,16 +140,14 @@ public class Executor<P: PrecisionType> { ...@@ -142,16 +140,14 @@ public class Executor<P: PrecisionType> {
// return // return
let afterDate = Date.init() let afterDate = Date.init()
var resultHolder: ResultHolder<P> var resultHolder: ResultHolder
if except > 0 { if except > 0 {
resultHolder = ResultHolder<P>.init(inDim: [], inResult: [], inElapsedTime: afterDate.timeIntervalSince(beforeDate), inIntermediateResults: outputTextures) resultHolder = ResultHolder.init(inDim: [], inResult: [], inElapsedTime: afterDate.timeIntervalSince(beforeDate), inIntermediateResults: outputTextures)
} else { } else {
let outputVar: Variant = self.program.scope.output()! let outputVar: Variant = self.program.scope.output()!
let output: Texture<P> = outputVar as! Texture<P> let output: Texture<P> = outputVar as! Texture<P>
resultHolder = ResultHolder<P>.init(inDim: output.dim.dims, inResult: output.metalTexture.floatArray(res: { (p:P) -> P in resultHolder = ResultHolder.init(inDim: output.dim.dims, inResult: output.toTensor(), inElapsedTime: afterDate.timeIntervalSince(beforeDate))
return p
}), inElapsedTime: afterDate.timeIntervalSince(beforeDate))
} }
completionHandle(resultHolder) completionHandle(resultHolder)
......
...@@ -168,7 +168,7 @@ public class Loader<P: PrecisionType> { ...@@ -168,7 +168,7 @@ public class Loader<P: PrecisionType> {
} }
} else { } else {
if varDesc.name == fetchKey { if varDesc.name == fetchKey {
scope[varDesc.name] = ResultHolder<P>.init(inDim: [], inResult: [], inElapsedTime: 0.0) scope[varDesc.name] = ResultHolder.init(inDim: [], inResult: [], inElapsedTime: 0.0)
} else if varDesc.name == feedKey { } else if varDesc.name == feedKey {
} }
} }
......
...@@ -59,28 +59,28 @@ class BoxcoderOp<P: PrecisionType>: Operator<BoxcoderKernel<P>, BoxcoderParam<P> ...@@ -59,28 +59,28 @@ class BoxcoderOp<P: PrecisionType>: Operator<BoxcoderKernel<P>, BoxcoderParam<P>
func delogOutput() { func delogOutput() {
print(" \(type) output: ") print(" \(type) output: ")
// let priorBoxOriginDim = para.priorBox.originDim // let priorBoxpadToFourDim = para.priorBox.padToFourDim
// let priorBoxArray: [Float32] = para.priorBox.metalTexture.realNHWC(dim: (n: priorBoxOriginDim[0], h: priorBoxOriginDim[1], w: priorBoxOriginDim[2], c: priorBoxOriginDim[3])) // let priorBoxArray: [Float32] = para.priorBox.metalTexture.realNHWC(dim: (n: priorBoxpadToFourDim[0], h: priorBoxpadToFourDim[1], w: priorBoxpadToFourDim[2], c: priorBoxpadToFourDim[3]))
// print(" prior box ") // print(" prior box ")
// print(priorBoxArray.strideArray()) // print(priorBoxArray.strideArray())
// //
// let priorBoxVarOriginDim = para.priorBoxVar.originDim // let priorBoxVarpadToFourDim = para.priorBoxVar.padToFourDim
// let priorBoxVarArray: [Float32] = para.priorBoxVar.metalTexture.realNHWC(dim: (n: priorBoxVarOriginDim[0], h: priorBoxVarOriginDim[1], w: priorBoxVarOriginDim[2], c: priorBoxVarOriginDim[3])) // let priorBoxVarArray: [Float32] = para.priorBoxVar.metalTexture.realNHWC(dim: (n: priorBoxVarpadToFourDim[0], h: priorBoxVarpadToFourDim[1], w: priorBoxVarpadToFourDim[2], c: priorBoxVarpadToFourDim[3]))
// print(" prior box var ") // print(" prior box var ")
// print(priorBoxVarArray.strideArray()) // print(priorBoxVarArray.strideArray())
// //
// let targetBoxOriginDim = para.targetBox.originDim // let targetBoxpadToFourDim = para.targetBox.padToFourDim
// let targetBoxArray: [Float32] = para.targetBox.metalTexture.realNHWC(dim: (n: targetBoxOriginDim[0], h: targetBoxOriginDim[1], w: targetBoxOriginDim[2], c: targetBoxOriginDim[3])) // let targetBoxArray: [Float32] = para.targetBox.metalTexture.realNHWC(dim: (n: targetBoxpadToFourDim[0], h: targetBoxpadToFourDim[1], w: targetBoxpadToFourDim[2], c: targetBoxpadToFourDim[3]))
// print(" target box ") // print(" target box ")
// print(targetBoxArray.strideArray()) // print(targetBoxArray.strideArray())
let targetBoxOriginDim = para.targetBox.originDim let targetBoxpadToFourDim = para.targetBox.padToFourDim
let targetBoxArray = para.targetBox.metalTexture.realNHWC(dim: (n: targetBoxOriginDim[0], h: targetBoxOriginDim[1], w: targetBoxOriginDim[2], c: targetBoxOriginDim[3])) let targetBoxArray = para.targetBox.metalTexture.realNHWC(dim: (n: targetBoxpadToFourDim[0], h: targetBoxpadToFourDim[1], w: targetBoxpadToFourDim[2], c: targetBoxpadToFourDim[3]))
print(" target box ") print(" target box ")
print(targetBoxArray.strideArray()) print(targetBoxArray.strideArray())
let originDim = para.output.originDim let padToFourDim = para.output.padToFourDim
let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: originDim[0], h: originDim[1], w: originDim[2], c: originDim[3])) let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3]))
print(" output ") print(" output ")
print(outputArray.strideArray()) print(outputArray.strideArray())
} }
......
...@@ -65,12 +65,12 @@ class ConcatOp<P: PrecisionType>: Operator<ConcatKernel<P>, ConcatParam<P>>, Run ...@@ -65,12 +65,12 @@ class ConcatOp<P: PrecisionType>: Operator<ConcatKernel<P>, ConcatParam<P>>, Run
func delogOutput() { func delogOutput() {
print(" \(type) output: ") print(" \(type) output: ")
let originDim = para.output.originDim let padToFourDim = para.output.padToFourDim
if para.output.transpose == [0, 1, 2, 3] { if para.output.transpose == [0, 1, 2, 3] {
let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: originDim[0], h: originDim[1], w: originDim[2], c: originDim[3])) let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3]))
print(outputArray.strideArray()) print(outputArray.strideArray())
} else if para.output.transpose == [0, 2, 3, 1] { } else if para.output.transpose == [0, 2, 3, 1] {
print(para.output.metalTexture.toTensor(dim: (n: originDim[0], c: originDim[1], h: originDim[2], w: originDim[3])).strideArray()) print(para.output.metalTexture.toTensor(dim: (n: padToFourDim[0], c: padToFourDim[1], h: padToFourDim[2], w: padToFourDim[3])).strideArray())
} else { } else {
fatalError(" not implemet") fatalError(" not implemet")
} }
......
...@@ -112,7 +112,7 @@ class ConvAddBatchNormReluOp<P: PrecisionType>: Operator<ConvAddBatchNormReluKer ...@@ -112,7 +112,7 @@ class ConvAddBatchNormReluOp<P: PrecisionType>: Operator<ConvAddBatchNormReluKer
func delogOutput() { func delogOutput() {
print(" conv add batchnorm relu output ") print(" conv add batchnorm relu output ")
print(para.output.metalTexture.toTensor(dim: (n: para.output.originDim[0], c: para.output.originDim[1], h: para.output.originDim[2], w: para.output.originDim[3])).strideArray()) print(para.output.toTensor().strideArray())
// let _: P? = para.input.metalTexture.logDesc(header: "conv add batchnorm relu input: ", stridable: false) // let _: P? = para.input.metalTexture.logDesc(header: "conv add batchnorm relu input: ", stridable: false)
// para.filter.logDataPointer(header: "filter data pointer: ") // para.filter.logDataPointer(header: "filter data pointer: ")
// print("filter: \(para.filter)") // print("filter: \(para.filter)")
......
...@@ -110,7 +110,7 @@ class ConvBNReluOp<P: PrecisionType>: Operator<ConvBNReluKernel<P>, ConvBNReluPa ...@@ -110,7 +110,7 @@ class ConvBNReluOp<P: PrecisionType>: Operator<ConvBNReluKernel<P>, ConvBNReluPa
func delogOutput() { func delogOutput() {
print(" \(type) output: ") print(" \(type) output: ")
print(para.output.metalTexture.toTensor(dim: (n: para.output.originDim[0], c: para.output.originDim[1], h: para.output.originDim[2], w: para.output.originDim[3])).strideArray()) print(para.output.metalTexture.toTensor(dim: (n: para.output.padToFourDim[0], c: para.output.padToFourDim[1], h: para.output.padToFourDim[2], w: para.output.padToFourDim[3])).strideArray())
} }
} }
...@@ -75,7 +75,7 @@ class ConvOp<P: PrecisionType>: Operator<ConvKernel<P>, ConvParam<P>>, Runable, ...@@ -75,7 +75,7 @@ class ConvOp<P: PrecisionType>: Operator<ConvKernel<P>, ConvParam<P>>, Runable,
func delogOutput() { func delogOutput() {
print("conv output : ") print("conv output : ")
print(para.output.metalTexture.toTensor(dim: (n: para.output.originDim[0], c: para.output.originDim[1], h: para.output.originDim[2], w: para.output.originDim[3])).strideArray()) print(para.output.toTensor().strideArray())
// let _: Float16? = para.output.metalTexture.logDesc() // let _: Float16? = para.output.metalTexture.logDesc()
} }
} }
...@@ -45,9 +45,9 @@ class ConvTransposeOp<P: PrecisionType>: Operator<ConvTransposeKernel<P>, ConvTr ...@@ -45,9 +45,9 @@ class ConvTransposeOp<P: PrecisionType>: Operator<ConvTransposeKernel<P>, ConvTr
func delogOutput() { func delogOutput() {
print(" \(type) output: ") print(" \(type) output: ")
let originDim = para.output.originDim let padToFourDim = para.output.padToFourDim
if para.output.transpose == [0, 1, 2, 3] { if para.output.transpose == [0, 1, 2, 3] {
let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: originDim[0], h: originDim[1], w: originDim[2], c: originDim[3])) let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3]))
print(outputArray.strideArray()) print(outputArray.strideArray())
} else if para.output.transpose == [0, 2, 3, 1] { } else if para.output.transpose == [0, 2, 3, 1] {
let output = para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])) let output = para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3]))
......
...@@ -58,6 +58,6 @@ class DepthConvOp<P: PrecisionType>: Operator<ConvKernel<P>, ConvParam<P>>, Runa ...@@ -58,6 +58,6 @@ class DepthConvOp<P: PrecisionType>: Operator<ConvKernel<P>, ConvParam<P>>, Runa
func delogOutput() { func delogOutput() {
print(" \(type) output: ") print(" \(type) output: ")
print(para.output.metalTexture.toTensor(dim: (n: para.output.originDim[0], c: para.output.originDim[1], h: para.output.originDim[2], w: para.output.originDim[3])).strideArray()) print(para.output.metalTexture.toTensor(dim: (n: para.output.padToFourDim[0], c: para.output.padToFourDim[1], h: para.output.padToFourDim[2], w: para.output.padToFourDim[3])).strideArray())
} }
} }
...@@ -65,6 +65,6 @@ class DwConvBNReluOp<P: PrecisionType>: Operator<ConvBNReluKernel<P>, ConvBNRelu ...@@ -65,6 +65,6 @@ class DwConvBNReluOp<P: PrecisionType>: Operator<ConvBNReluKernel<P>, ConvBNRelu
func delogOutput() { func delogOutput() {
print(" \(type) output: ") print(" \(type) output: ")
print(para.output.metalTexture.toTensor(dim: (n: para.output.originDim[0], c: para.output.originDim[1], h: para.output.originDim[2], w: para.output.originDim[3])).strideArray()) print(para.output.metalTexture.toTensor(dim: (n: para.output.padToFourDim[0], c: para.output.padToFourDim[1], h: para.output.padToFourDim[2], w: para.output.padToFourDim[3])).strideArray())
} }
} }
...@@ -74,9 +74,9 @@ class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddKernel<P>, Elem ...@@ -74,9 +74,9 @@ class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddKernel<P>, Elem
print(para.inputY) print(para.inputY)
let originDim = para.output.originDim let padToFourDim = para.output.padToFourDim
if para.output.transpose == [0, 1, 2, 3] { if para.output.transpose == [0, 1, 2, 3] {
let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: originDim[0], h: originDim[1], w: originDim[2], c: originDim[3])) let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3]))
print(outputArray.strideArray()) print(outputArray.strideArray())
} else if para.output.transpose == [0, 2, 3, 1] { } else if para.output.transpose == [0, 2, 3, 1] {
print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray()) print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray())
......
...@@ -61,7 +61,7 @@ class FeedOp<P: PrecisionType>: Operator<Texture2DTo2DArrayKernel<P>, FeedParam< ...@@ -61,7 +61,7 @@ class FeedOp<P: PrecisionType>: Operator<Texture2DTo2DArrayKernel<P>, FeedParam<
func delogOutput() { func delogOutput() {
print(" \(type) output: ") print(" \(type) output: ")
print(para.output.metalTexture.toTensor(dim: (n: para.output.originDim[0], c: para.output.originDim[1], h: para.output.originDim[2], w: para.output.originDim[3])).strideArray()) print(para.output.metalTexture.toTensor(dim: (n: para.output.padToFourDim[0], c: para.output.padToFourDim[1], h: para.output.padToFourDim[2], w: para.output.padToFourDim[3])).strideArray())
} }
} }
...@@ -53,11 +53,11 @@ class PriorBoxKernel<P: PrecisionType>: Kernel, Computable{ ...@@ -53,11 +53,11 @@ class PriorBoxKernel<P: PrecisionType>: Kernel, Computable{
param.output.dim = Dim.init(inDim: [n, h, w, c]) param.output.dim = Dim.init(inDim: [n, h, w, c])
param.output.transpose = [0, 1, 2, 3] param.output.transpose = [0, 1, 2, 3]
let imageWidth = Float32(param.inputImage.originDim[3]) let imageWidth = Float32(param.inputImage.padToFourDim[3])
let imageHeight = Float32(param.inputImage.originDim[2]) let imageHeight = Float32(param.inputImage.padToFourDim[2])
let featureWidth = param.input.originDim[3] let featureWidth = param.input.padToFourDim[3]
let featureHeight = param.input.originDim[2] let featureHeight = param.input.padToFourDim[2]
if param.stepW == 0 || param.stepH == 0 { if param.stepW == 0 || param.stepH == 0 {
param.stepW = Float32(imageWidth) / Float32(featureWidth) param.stepW = Float32(imageWidth) / Float32(featureWidth)
......
...@@ -51,13 +51,13 @@ class PreluOp<P: PrecisionType>: Operator<PreluKernel<P>, PreluParam<P>>, Runabl ...@@ -51,13 +51,13 @@ class PreluOp<P: PrecisionType>: Operator<PreluKernel<P>, PreluParam<P>>, Runabl
func delogOutput() { func delogOutput() {
print(" \(type) input: ") print(" \(type) input: ")
print(para.input.metalTexture.toTensor(dim: (n: para.input.originDim[0], c: para.input.originDim[1], h: para.input.originDim[2], w: para.input.originDim[3])).strideArray()) print(para.input.metalTexture.toTensor(dim: (n: para.input.padToFourDim[0], c: para.input.padToFourDim[1], h: para.input.padToFourDim[2], w: para.input.padToFourDim[3])).strideArray())
print(" \(type) Alpha: ") print(" \(type) Alpha: ")
let _: Float32? = para.alpha.buffer.logDesc(header: " alpha: ", stridable: false) let _: Float32? = para.alpha.buffer.logDesc(header: " alpha: ", stridable: false)
print(" \(type) output: ") print(" \(type) output: ")
print(para.output.metalTexture.toTensor(dim: (n: para.output.originDim[0], c: para.output.originDim[1], h: para.output.originDim[2], w: para.output.originDim[3])).strideArray()) print(para.output.metalTexture.toTensor(dim: (n: para.output.padToFourDim[0], c: para.output.padToFourDim[1], h: para.output.padToFourDim[2], w: para.output.padToFourDim[3])).strideArray())
} }
// print("softmax delog") // print("softmax delog")
......
...@@ -76,12 +76,12 @@ class PriorBoxOp<P: PrecisionType>: Operator<PriorBoxKernel<P>, PriorBoxParam<P> ...@@ -76,12 +76,12 @@ class PriorBoxOp<P: PrecisionType>: Operator<PriorBoxKernel<P>, PriorBoxParam<P>
print(outputArray) print(outputArray)
// output // output
// print(" \(type) output: ") // print(" \(type) output: ")
// let originDim = para.output.originDim // let padToFourDim = para.output.padToFourDim
// if para.output.transpose == [0, 1, 2, 3] { // if para.output.transpose == [0, 1, 2, 3] {
// let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: originDim[0], h: originDim[1], w: originDim[2], c: originDim[3]), texturePrecision: computePrecision) // let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3]), texturePrecision: computePrecision)
// print(outputArray.strideArray()) // print(outputArray.strideArray())
// } else if para.output.transpose == [0, 2, 3, 1] { // } else if para.output.transpose == [0, 2, 3, 1] {
// print(para.output.metalTexture.toTensor(dim: (n: originDim[0], c: originDim[1], h: originDim[2], w: originDim[3]), texturePrecision: computePrecision).strideArray()) // print(para.output.metalTexture.toTensor(dim: (n: padToFourDim[0], c: padToFourDim[1], h: padToFourDim[2], w: padToFourDim[3]), texturePrecision: computePrecision).strideArray())
// } else { // } else {
// print(" not implement") // print(" not implement")
// } // }
......
...@@ -41,8 +41,8 @@ class ReshapeParam<P: PrecisionType>: OpParam { ...@@ -41,8 +41,8 @@ class ReshapeParam<P: PrecisionType>: OpParam {
for i in 0..<s.count { for i in 0..<s.count {
dim[4-s.count+i] = s[i] dim[4-s.count+i] = s[i]
} }
output.originDim = Dim.init(inDim: dim) output.padToFourDim = Dim.init(inDim: dim)
output.dim = output.originDim output.dim = output.padToFourDim
inplace = try ReshapeParam.getAttr(key: "inplace", attrs: opDesc.attrs) inplace = try ReshapeParam.getAttr(key: "inplace", attrs: opDesc.attrs)
} catch let error { } catch let error {
...@@ -74,9 +74,9 @@ class ReshapeOp<P: PrecisionType>: Operator<ReshapeKernel<P>, ReshapeParam<P>>, ...@@ -74,9 +74,9 @@ class ReshapeOp<P: PrecisionType>: Operator<ReshapeKernel<P>, ReshapeParam<P>>,
print("reshape delog") print("reshape delog")
// let _: P? = para.input.metalTexture.logDesc(header: "reshape input: ", stridable: false) // let _: P? = para.input.metalTexture.logDesc(header: "reshape input: ", stridable: false)
let originDim = para.output.originDim let padToFourDim = para.output.padToFourDim
let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: originDim[0], h: originDim[1], w: originDim[2], c: originDim[3])) let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3]))
print(outputArray.strideArray()) print(outputArray.strideArray())
} }
......
...@@ -26,7 +26,7 @@ class SoftmaxParam<P: PrecisionType>: OpParam { ...@@ -26,7 +26,7 @@ class SoftmaxParam<P: PrecisionType>: OpParam {
output.dim = input.dim output.dim = input.dim
output.tensorDim = input.tensorDim output.tensorDim = input.tensorDim
output.originDim = input.originDim output.padToFourDim = input.padToFourDim
} catch let error { } catch let error {
throw error throw error
} }
...@@ -55,8 +55,8 @@ class SoftmaxOp<P: PrecisionType>: Operator<SoftmaxKernel<P>, SoftmaxParam<P>>, ...@@ -55,8 +55,8 @@ class SoftmaxOp<P: PrecisionType>: Operator<SoftmaxKernel<P>, SoftmaxParam<P>>,
print(para.input) print(para.input)
print(para.output) print(para.output)
let originDim = para.output.originDim let padToFourDim = para.output.padToFourDim
let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: originDim[0], h: originDim[1], w: originDim[2], c: originDim[3])) let outputArray: [Float32] = para.output.metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3]))
print(outputArray.strideArray()) print(outputArray.strideArray())
} }
} }
...@@ -48,9 +48,9 @@ class TransposeOp<P: PrecisionType>: Operator<TransposeKernel<P>, TransposeParam ...@@ -48,9 +48,9 @@ class TransposeOp<P: PrecisionType>: Operator<TransposeKernel<P>, TransposeParam
func delogOutput() { func delogOutput() {
print(" \(type) output: ") print(" \(type) output: ")
let originDim = para.output.originDim let padToFourDim = para.output.padToFourDim
if para.output.transpose == [0, 1, 2, 3] { if para.output.transpose == [0, 1, 2, 3] {
let outputArray = para.output.metalTexture.realNHWC(dim: (n: originDim[0], h: originDim[1], w: originDim[2], c: originDim[3])) let outputArray = para.output.metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3]))
print(outputArray.strideArray()) print(outputArray.strideArray())
} else if para.output.transpose == [0, 2, 3, 1] { } else if para.output.transpose == [0, 2, 3, 1] {
print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray()) print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray())
......
...@@ -41,14 +41,28 @@ extension InputTexture { ...@@ -41,14 +41,28 @@ extension InputTexture {
public class Texture<P: PrecisionType>: Tensorial { public class Texture<P: PrecisionType>: Tensorial {
var dim: Dim var dim: Dim
public var tensorDim: Dim public var tensorDim: Dim
public var originDim: Dim public var padToFourDim: Dim
private var textureDesc: MTLTextureDescriptor! private var textureDesc: MTLTextureDescriptor!
public var metalTexture: MTLTexture! public var metalTexture: MTLTexture!
var transpose: [Int] = [0, 1, 2, 3] var transpose: [Int] = [0, 1, 2, 3]
/// Returns the result of `metalTexture.toTensor(dim:)`, called with the four
/// entries of `padToFourDim` passed as `(n, c, h, w)` in index order 0...3.
///
/// - Returns: The `[Float32]` array produced by `metalTexture.toTensor(dim:)`.
/// - Note: Traps via `fatalError` when `padToFourDim.cout()` is not 4.
func toTensor() -> [Float32] {
    let rank = padToFourDim.cout()
    guard rank == 4 else {
        // Name the method and the offending rank so the crash log is actionable.
        fatalError("Texture.toTensor() not supported: padToFourDim must have 4 dimensions, got \(rank)")
    }
    return metalTexture.toTensor(dim: (n: padToFourDim[0], c: padToFourDim[1], h: padToFourDim[2], w: padToFourDim[3]))
}
/// Returns the result of `metalTexture.realNHWC(dim:)`, called with the four
/// entries of `padToFourDim` passed as `(n, h, w, c)` in index order 0...3.
///
/// - Returns: The `[Float32]` array produced by `metalTexture.realNHWC(dim:)`.
/// - Note: Traps via `fatalError` when `padToFourDim.cout()` is not 4.
func realNHWC() -> [Float32] {
    let rank = padToFourDim.cout()
    guard rank == 4 else {
        // Name the method and the offending rank so the crash log is actionable.
        fatalError("Texture.realNHWC() not supported: padToFourDim must have 4 dimensions, got \(rank)")
    }
    return metalTexture.realNHWC(dim: (n: padToFourDim[0], h: padToFourDim[1], w: padToFourDim[2], c: padToFourDim[3]))
}
func initTexture(device: MTLDevice, inTranspose: [Int] = [0, 1, 2, 3], computePrecision: ComputePrecision = .Float16) { func initTexture(device: MTLDevice, inTranspose: [Int] = [0, 1, 2, 3], computePrecision: ComputePrecision = .Float16) {
transpose = inTranspose transpose = inTranspose
let newDim = transpose.map { originDim[$0] } let newDim = transpose.map { padToFourDim[$0] }
let newLayout = transpose.map { layout.layoutWithDim[$0] } let newLayout = transpose.map { layout.layoutWithDim[$0] }
...@@ -93,7 +107,7 @@ public class Texture<P: PrecisionType>: Tensorial { ...@@ -93,7 +107,7 @@ public class Texture<P: PrecisionType>: Tensorial {
} }
tensorDim = inDim tensorDim = inDim
dim = fourDim dim = fourDim
originDim = fourDim padToFourDim = fourDim
layout = DataLayout.init([(.N, fourDim[0]), (.C, fourDim[1]), (.H, fourDim[2]), (.W, fourDim[3])]) layout = DataLayout.init([(.N, fourDim[0]), (.C, fourDim[1]), (.H, fourDim[2]), (.W, fourDim[3])])
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册