提交 3382da34 编写于 作者: L liuruilong

fix precision error

上级 b9876e17
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8820E11C560081E9F8 /* Assets.xcassets */; }; FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8820E11C560081E9F8 /* Assets.xcassets */; };
FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8A20E11C560081E9F8 /* LaunchScreen.storyboard */; }; FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8A20E11C560081E9F8 /* LaunchScreen.storyboard */; };
FC203FB221CBFDBA00B37166 /* test.jpg in Resources */ = {isa = PBXBuildFile; fileRef = FC203FA921CBFDBA00B37166 /* test.jpg */; }; FC203FB221CBFDBA00B37166 /* test.jpg in Resources */ = {isa = PBXBuildFile; fileRef = FC203FA921CBFDBA00B37166 /* test.jpg */; };
FC5E03B221DCE8D90016C137 /* mingren_input_data in Resources */ = {isa = PBXBuildFile; fileRef = FC5E03B121DCE8D90016C137 /* mingren_input_data */; };
FC704C1921D2375300F98BAB /* super_params in Resources */ = {isa = PBXBuildFile; fileRef = FC704C1721D2375300F98BAB /* super_params */; }; FC704C1921D2375300F98BAB /* super_params in Resources */ = {isa = PBXBuildFile; fileRef = FC704C1721D2375300F98BAB /* super_params */; };
FC704C1A21D2375300F98BAB /* super_model in Resources */ = {isa = PBXBuildFile; fileRef = FC704C1821D2375300F98BAB /* super_model */; }; FC704C1A21D2375300F98BAB /* super_model in Resources */ = {isa = PBXBuildFile; fileRef = FC704C1821D2375300F98BAB /* super_model */; };
FC704C2221D237FC00F98BAB /* combined_mobilenet_params in Resources */ = {isa = PBXBuildFile; fileRef = FC704C1D21D237FC00F98BAB /* combined_mobilenet_params */; }; FC704C2221D237FC00F98BAB /* combined_mobilenet_params in Resources */ = {isa = PBXBuildFile; fileRef = FC704C1D21D237FC00F98BAB /* combined_mobilenet_params */; };
...@@ -32,6 +33,7 @@ ...@@ -32,6 +33,7 @@
FC9797CF21D6506F00F2FD90 /* mingren.jpg in Resources */ = {isa = PBXBuildFile; fileRef = FC9797CE21D6506F00F2FD90 /* mingren.jpg */; }; FC9797CF21D6506F00F2FD90 /* mingren.jpg in Resources */ = {isa = PBXBuildFile; fileRef = FC9797CE21D6506F00F2FD90 /* mingren.jpg */; };
FC9797D121D6616600F2FD90 /* BufferToTexture.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC9797D021D6616600F2FD90 /* BufferToTexture.metal */; }; FC9797D121D6616600F2FD90 /* BufferToTexture.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC9797D021D6616600F2FD90 /* BufferToTexture.metal */; };
FCBCCC552122EF5500D94F7E /* MetalHelper.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCBCCC542122EF5400D94F7E /* MetalHelper.swift */; }; FCBCCC552122EF5500D94F7E /* MetalHelper.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCBCCC542122EF5400D94F7E /* MetalHelper.swift */; };
FCCED60521D7646E00BE8D5F /* test_image_super in Resources */ = {isa = PBXBuildFile; fileRef = FCCED60421D7646E00BE8D5F /* test_image_super */; };
FCEBEC2C20E1391F00C0B14D /* paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; }; FCEBEC2C20E1391F00C0B14D /* paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; };
FCEBEC2D20E1391F00C0B14D /* paddle_mobile.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; }; FCEBEC2D20E1391F00C0B14D /* paddle_mobile.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
FCF437E8214B6DDB00943429 /* MultiPredictViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCF437E7214B6DDB00943429 /* MultiPredictViewController.swift */; }; FCF437E8214B6DDB00943429 /* MultiPredictViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCF437E7214B6DDB00943429 /* MultiPredictViewController.swift */; };
...@@ -68,6 +70,7 @@ ...@@ -68,6 +70,7 @@
FC203FA921CBFDBA00B37166 /* test.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = test.jpg; sourceTree = "<group>"; }; FC203FA921CBFDBA00B37166 /* test.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = test.jpg; sourceTree = "<group>"; };
FC27991121343A39000B6BAD /* paddle-mobile-demo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "paddle-mobile-demo-Bridging-Header.h"; sourceTree = "<group>"; }; FC27991121343A39000B6BAD /* paddle-mobile-demo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "paddle-mobile-demo-Bridging-Header.h"; sourceTree = "<group>"; };
FC4FD97B2140EE250073E130 /* libc++.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = "libc++.tbd"; path = "usr/lib/libc++.tbd"; sourceTree = SDKROOT; }; FC4FD97B2140EE250073E130 /* libc++.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = "libc++.tbd"; path = "usr/lib/libc++.tbd"; sourceTree = SDKROOT; };
FC5E03B121DCE8D90016C137 /* mingren_input_data */ = {isa = PBXFileReference; lastKnownFileType = file; path = mingren_input_data; sourceTree = "<group>"; };
FC704C1721D2375300F98BAB /* super_params */ = {isa = PBXFileReference; lastKnownFileType = file; path = super_params; sourceTree = "<group>"; }; FC704C1721D2375300F98BAB /* super_params */ = {isa = PBXFileReference; lastKnownFileType = file; path = super_params; sourceTree = "<group>"; };
FC704C1821D2375300F98BAB /* super_model */ = {isa = PBXFileReference; lastKnownFileType = file; path = super_model; sourceTree = "<group>"; }; FC704C1821D2375300F98BAB /* super_model */ = {isa = PBXFileReference; lastKnownFileType = file; path = super_model; sourceTree = "<group>"; };
FC704C1D21D237FC00F98BAB /* combined_mobilenet_params */ = {isa = PBXFileReference; lastKnownFileType = file; path = combined_mobilenet_params; sourceTree = "<group>"; }; FC704C1D21D237FC00F98BAB /* combined_mobilenet_params */ = {isa = PBXFileReference; lastKnownFileType = file; path = combined_mobilenet_params; sourceTree = "<group>"; };
...@@ -83,6 +86,7 @@ ...@@ -83,6 +86,7 @@
FC9797CE21D6506F00F2FD90 /* mingren.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = mingren.jpg; sourceTree = "<group>"; }; FC9797CE21D6506F00F2FD90 /* mingren.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = mingren.jpg; sourceTree = "<group>"; };
FC9797D021D6616600F2FD90 /* BufferToTexture.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = BufferToTexture.metal; sourceTree = "<group>"; }; FC9797D021D6616600F2FD90 /* BufferToTexture.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = BufferToTexture.metal; sourceTree = "<group>"; };
FCBCCC542122EF5400D94F7E /* MetalHelper.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MetalHelper.swift; sourceTree = "<group>"; }; FCBCCC542122EF5400D94F7E /* MetalHelper.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MetalHelper.swift; sourceTree = "<group>"; };
FCCED60421D7646E00BE8D5F /* test_image_super */ = {isa = PBXFileReference; lastKnownFileType = file; path = test_image_super; sourceTree = "<group>"; };
FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; }; FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; };
FCF437E7214B6DDB00943429 /* MultiPredictViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MultiPredictViewController.swift; sourceTree = "<group>"; }; FCF437E7214B6DDB00943429 /* MultiPredictViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MultiPredictViewController.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */ /* End PBXFileReference section */
...@@ -162,6 +166,8 @@ ...@@ -162,6 +166,8 @@
FC203FA821CBFDBA00B37166 /* images */ = { FC203FA821CBFDBA00B37166 /* images */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC5E03B121DCE8D90016C137 /* mingren_input_data */,
FCCED60421D7646E00BE8D5F /* test_image_super */,
FC9797CE21D6506F00F2FD90 /* mingren.jpg */, FC9797CE21D6506F00F2FD90 /* mingren.jpg */,
FC9797BD21D6045B00F2FD90 /* banana.jpeg */, FC9797BD21D6045B00F2FD90 /* banana.jpeg */,
FC203FA921CBFDBA00B37166 /* test.jpg */, FC203FA921CBFDBA00B37166 /* test.jpg */,
...@@ -308,12 +314,14 @@ ...@@ -308,12 +314,14 @@
isa = PBXResourcesBuildPhase; isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647; buildActionMask = 2147483647;
files = ( files = (
FCCED60521D7646E00BE8D5F /* test_image_super in Resources */,
FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */, FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */,
FC9797CF21D6506F00F2FD90 /* mingren.jpg in Resources */, FC9797CF21D6506F00F2FD90 /* mingren.jpg in Resources */,
FC704C2221D237FC00F98BAB /* combined_mobilenet_params in Resources */, FC704C2221D237FC00F98BAB /* combined_mobilenet_params in Resources */,
FC704C1921D2375300F98BAB /* super_params in Resources */, FC704C1921D2375300F98BAB /* super_params in Resources */,
FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */, FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */,
FC9797C721D609FB00F2FD90 /* synset.txt in Resources */, FC9797C721D609FB00F2FD90 /* synset.txt in Resources */,
FC5E03B221DCE8D90016C137 /* mingren_input_data in Resources */,
FC704C1A21D2375300F98BAB /* super_model in Resources */, FC704C1A21D2375300F98BAB /* super_model in Resources */,
FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */, FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */,
FC9797C221D608E000F2FD90 /* mobilenet_model in Resources */, FC9797C221D608E000F2FD90 /* mobilenet_model in Resources */,
......
...@@ -18,6 +18,34 @@ import CoreMedia ...@@ -18,6 +18,34 @@ import CoreMedia
import paddle_mobile import paddle_mobile
import MetalPerformanceShaders import MetalPerformanceShaders
/// Loads the complete contents of a binary file through the C stdio API.
///
/// The file is opened and measured once in `init`; the handle stays open for
/// the lifetime of the reader and is closed in `deinit`.
class FileReader {
    /// Open stdio handle for the file; closed in `deinit`.
    let file: UnsafeMutablePointer<FILE>
    /// Size of the file in bytes, measured when the reader was created.
    let fileSize: Int

    /// Opens `paramPath` for binary reading and records its size.
    ///
    /// - Parameter paramPath: Path of the file to open.
    /// - Throws: `PaddleMobileError.loaderError` when the file cannot be
    ///   opened or when its measured size is not greater than zero.
    init(paramPath: String) throws {
        guard let tmpFile = fopen(paramPath, "rb") else {
            throw PaddleMobileError.loaderError(message: "open param file error" + paramPath)
        }
        // Measure the file before any stored property is assigned, so the
        // handle can be closed right here on failure instead of depending on
        // whether `deinit` runs for a failed initializer.
        let measuredSize: Int = fseek(tmpFile, 0, SEEK_END) == 0 ? ftell(tmpFile) : -1
        guard measuredSize > 0 else {
            fclose(tmpFile)
            throw PaddleMobileError.loaderError(message: "param file size is too small")
        }
        rewind(tmpFile)
        file = tmpFile
        fileSize = measuredSize
    }

    /// Reads the whole file into a newly allocated buffer of `T`.
    ///
    /// The buffer holds `fileSize` bytes rounded up to a whole number of `T`
    /// elements; any bytes not filled from the file are zero.
    ///
    /// - Returns: Pointer to the buffer. The caller owns it and must call
    ///   `deallocate()` on it when done.
    func read<T>() -> UnsafeMutablePointer<T> {
        // `allocate(capacity:)` counts elements, not bytes, so convert the
        // byte size into an element count (rounded up).
        let stride = MemoryLayout<T>.stride
        let elementCount = (fileSize + stride - 1) / stride
        let ptr = UnsafeMutablePointer<T>.allocate(capacity: elementCount)
        // Zero the buffer so a short read never exposes uninitialised memory.
        memset(ptr, 0, elementCount * stride)
        // Always start from the beginning so repeated calls return the same data.
        rewind(file)
        // Item size 1 makes fread report the number of bytes actually read.
        let bytesRead = fread(ptr, 1, fileSize, file)
        if bytesRead != fileSize {
            print("FileReader: short read, expected \(fileSize) bytes but got \(bytesRead)")
        }
        return ptr
    }

    deinit {
        fclose(file)
    }
}
enum Platform { enum Platform {
case GPU case GPU
} }
...@@ -66,10 +94,24 @@ class ViewController: UIViewController { ...@@ -66,10 +94,24 @@ class ViewController: UIViewController {
@IBAction func loadAct(_ sender: Any) { @IBAction func loadAct(_ sender: Any) {
runner = Runner.init(inNet: netSupport[modelType]!, commandQueue: MetalHelper.shared.queue) runner = Runner.init(inNet: netSupport[modelType]!, commandQueue: MetalHelper.shared.queue)
if platform == .GPU { if platform == .GPU {
let filePath = Bundle.main.path(forResource: "mingren_input_data", ofType: nil)
let fileReader = try! FileReader.init(paramPath: filePath!)
let pointer: UnsafeMutablePointer<Float32> = fileReader.read()
let buffer = MetalHelper.shared.device.makeBuffer(length: fileReader.fileSize, options: .storageModeShared)
buffer?.contents().copyMemory(from: pointer, byteCount: fileReader.fileSize)
if self.toPredictTexture == nil { if self.toPredictTexture == nil {
runner.getTexture(image: selectImage!.cgImage!) { [weak self] (texture) in
runner.getTexture(inBuffer: buffer!) { [weak self] (texture) in
self?.toPredictTexture = texture self?.toPredictTexture = texture
} }
// runner.getTexture(image: selectImage!.cgImage!) { [weak self] (texture) in
// }
} }
} else { } else {
fatalError( " unsupport " ) fatalError( " unsupport " )
...@@ -108,7 +150,8 @@ class ViewController: UIViewController { ...@@ -108,7 +150,8 @@ class ViewController: UIViewController {
guard let sSelf = self else { guard let sSelf = self else {
fatalError() fatalError()
} }
if let inResultHolder = resultHolder, success {
if success, let inResultHolder = resultHolder {
if i == max - 1 { if i == max - 1 {
let time = Date.init().timeIntervalSince(startDate) let time = Date.init().timeIntervalSince(startDate)
......
...@@ -15,13 +15,6 @@ ...@@ -15,13 +15,6 @@
import Foundation import Foundation
/// Custom kernel bound to the "super_resolution_preprocess" function, with a
/// 224 x 224 x 3 output shape and `usePaddleMobileLib` switched off.
class SuperResolutionPreProccess: CusomKernel {
    /// - Parameter device: Metal device forwarded to the `CusomKernel` initializer.
    init(device: MTLDevice) {
        super.init(
            device: device,
            inFunctionName: "super_resolution_preprocess",
            outputDim: Shape.init(inWidth: 224, inHeight: 224, inChannel: 3),
            usePaddleMobileLib: false
        )
    }
}
public class SuperResolutionNet: Net{ public class SuperResolutionNet: Net{
override public func resultStr(res: ResultHolder) -> String { override public func resultStr(res: ResultHolder) -> String {
return "未实现" return "未实现"
......
...@@ -63,7 +63,8 @@ class FeedOp<P: PrecisionType>: Operator<Texture2DTo2DArrayKernel<P>, FeedParam< ...@@ -63,7 +63,8 @@ class FeedOp<P: PrecisionType>: Operator<Texture2DTo2DArrayKernel<P>, FeedParam<
func delogOutput() { func delogOutput() {
print(" \(type) output: ") print(" \(type) output: ")
print(para.output.metalTexture.toTensor(dim: (n: para.output.padToFourDim[0], c: para.output.padToFourDim[1], h: para.output.padToFourDim[2], w: para.output.padToFourDim[3])).strideArray())
print(para.output.metalTexture.toTensor(dim: (n: para.output.padToFourDim[0], c: para.output.padToFourDim[3], h: para.output.padToFourDim[2], w: para.output.padToFourDim[1])).strideArray())
} }
} }
...@@ -18,13 +18,14 @@ class ConvAddKernel<P: PrecisionType>: Kernel, Computable { ...@@ -18,13 +18,14 @@ class ConvAddKernel<P: PrecisionType>: Kernel, Computable {
var metalParam: MetalConvParam! var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvAddParam<P>) { required init(device: MTLDevice, param: ConvAddParam<P>) {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: computePrecision) param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: computePrecision)
param.filter.initBuffer(device: device, precision: computePrecision) let padWhenOneC = !(param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1])
param.filter.initBuffer(device: device, precision: computePrecision, padWhenOneC: padWhenOneC)
param.y.initBuffer(device: device, precision: computePrecision) param.y.initBuffer(device: device, precision: computePrecision)
if computePrecision == .Float16 { if computePrecision == .Float16 {
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_add_1x1_half") super.init(device: device, inFunctionName: "conv_add_1x1_half")
} else if param.filter.channel == 1 { } else if param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1] {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_half") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_half")
} else if param.filter.width == 3 && param.filter.height == 3 { } else if param.filter.width == 3 && param.filter.height == 3 {
super.init(device: device, inFunctionName: "conv_add_3x3_half") super.init(device: device, inFunctionName: "conv_add_3x3_half")
...@@ -38,7 +39,7 @@ class ConvAddKernel<P: PrecisionType>: Kernel, Computable { ...@@ -38,7 +39,7 @@ class ConvAddKernel<P: PrecisionType>: Kernel, Computable {
} else if computePrecision == .Float32 { } else if computePrecision == .Float32 {
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_add_1x1") super.init(device: device, inFunctionName: "conv_add_1x1")
} else if param.filter.channel == 1 { } else if param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1] {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3") super.init(device: device, inFunctionName: "depthwise_conv_add_3x3")
} else if param.filter.width == 1 && param.filter.height == 5 { } else if param.filter.width == 1 && param.filter.height == 5 {
super.init(device: device, inFunctionName: "conv_add_5x1") super.init(device: device, inFunctionName: "conv_add_5x1")
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import Foundation import Foundation
let testTo = 22 let testTo = 5
var isTest = false var isTest = false
...@@ -144,15 +144,13 @@ public class Executor<P: PrecisionType> { ...@@ -144,15 +144,13 @@ public class Executor<P: PrecisionType> {
guard let SSelf = self else { guard let SSelf = self else {
fatalError() fatalError()
} }
//将输入写进文件 //将输入写进文件
/* /*
let inputArr = resInput.toTensor(dim: (n: dim[0], c: dim[3], h: dim[1], w: dim[2])) let inputArr = resInput.toTensor(dim: (n: dim[0], c: dim[3], h: dim[1], w: dim[2]))
print(dim) print(dim)
writeToLibrary(fileName: "test_image_mingren", array: inputArr) writeToLibrary(fileName: "test_image_super", array: inputArr)
print(" write done ") print(" write done ")
return return
*/ */
......
...@@ -97,7 +97,7 @@ class Tensor<P: PrecisionType>: Tensorial { ...@@ -97,7 +97,7 @@ class Tensor<P: PrecisionType>: Tensorial {
func initBuffer(device: MTLDevice, precision: ComputePrecision = .Float16, convertToNHWC: Bool = true, withTranspose: Bool = false) { func initBuffer(device: MTLDevice, precision: ComputePrecision = .Float16, padWhenOneC: Bool = false, convertToNHWC: Bool = true, withTranspose: Bool = false) {
if convertToNHWC { if convertToNHWC {
// print(layout) // print(layout)
convert(to: DataLayout.NHWC()) convert(to: DataLayout.NHWC())
...@@ -145,7 +145,7 @@ class Tensor<P: PrecisionType>: Tensorial { ...@@ -145,7 +145,7 @@ class Tensor<P: PrecisionType>: Tensorial {
case .Float16: case .Float16:
float32ToFloat16(input: floatPointer, output: buffer.contents(), count: count) float32ToFloat16(input: floatPointer, output: buffer.contents(), count: count)
} }
} else if C == 1 { } else if C == 1 && !padWhenOneC {
buffer = device.makeBuffer(length: numel() * precisionSize) buffer = device.makeBuffer(length: numel() * precisionSize)
switch precision { switch precision {
case .Float32: case .Float32:
...@@ -238,10 +238,32 @@ class Tensor<P: PrecisionType>: Tensorial { ...@@ -238,10 +238,32 @@ class Tensor<P: PrecisionType>: Tensorial {
data.release() data.release()
} }
/// First dimension of a 4-D tensor (`dim[0]`), returned for both the NCHW
/// and the NHWC layout.
///
/// Traps when the tensor is not 4-dimensional or uses any other layout.
var n: Int {
    guard dim.cout() == 4 else {
        fatalError()
    }
    let isSupportedLayout = layout == DataLayout.NCHW() || layout == DataLayout.NHWC()
    guard isSupportedLayout else {
        fatalError(" unsupport ")
    }
    return dim[0]
}
var width: Int { var width: Int {
get { get {
if dim.cout() == 4 { if dim.cout() == 4 {
return dim[1] if layout == DataLayout.NHWC() {
return dim[2]
} else if layout == DataLayout.NCHW() {
return dim[3]
} else {
fatalError(" unsupport ")
}
} else { } else {
fatalError() fatalError()
} }
...@@ -251,7 +273,13 @@ class Tensor<P: PrecisionType>: Tensorial { ...@@ -251,7 +273,13 @@ class Tensor<P: PrecisionType>: Tensorial {
var height: Int { var height: Int {
get { get {
if dim.cout() == 4 { if dim.cout() == 4 {
return dim[2] if layout == DataLayout.NHWC() {
return dim[1]
} else if layout == DataLayout.NCHW() {
return dim[2]
} else {
fatalError(" unsupport ")
}
} else { } else {
fatalError() fatalError()
} }
...@@ -261,7 +289,13 @@ class Tensor<P: PrecisionType>: Tensorial { ...@@ -261,7 +289,13 @@ class Tensor<P: PrecisionType>: Tensorial {
var channel: Int { var channel: Int {
get { get {
if dim.cout() == 4 { if dim.cout() == 4 {
return dim[3] if layout == DataLayout.NHWC() {
return dim[3]
} else if layout == DataLayout.NCHW() {
return dim[1]
} else {
fatalError(" unsupport ")
}
} else { } else {
fatalError() fatalError()
} }
......
...@@ -21,7 +21,7 @@ int main() { ...@@ -21,7 +21,7 @@ int main() {
paddle_mobile::PaddleMobileConfigInternal config; paddle_mobile::PaddleMobileConfigInternal config;
config.load_when_predict = true; config.load_when_predict = true;
paddle_mobile::PaddleMobile<paddle_mobile::GPU_CL> paddle_mobile(config); paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile(config);
// paddle_mobile.SetThreadNum(4); // paddle_mobile.SetThreadNum(4);
auto time1 = paddle_mobile::time(); auto time1 = paddle_mobile::time();
#ifdef PADDLE_MOBILE_CL #ifdef PADDLE_MOBILE_CL
...@@ -38,84 +38,84 @@ int main() { ...@@ -38,84 +38,84 @@ int main() {
std::cout << "load cost :" << paddle_mobile::time_diff(time1, time2) << "ms" std::cout << "load cost :" << paddle_mobile::time_diff(time1, time2) << "ms"
<< std::endl; << std::endl;
// 300*300 // 300*300
std::vector<float> input; // std::vector<float> input;
std::vector<int64_t> dims{1, 1, 300, 300}; // std::vector<int64_t> dims{1, 1, 300, 300};
GetInput<float>(g_test_image_1x3x224x224, &input, dims); // GetInput<float>(g_test_image_1x3x224x224, &input, dims);
//
std::vector<float> vec_result; // std::vector<float> vec_result;
auto time3 = paddle_mobile::time(); auto time3 = paddle_mobile::time();
int max = 10; int max = 1;
for (int i = 0; i < max; ++i) { // for (int i = 0; i < max; ++i) {
auto time5 = paddle_mobile::time(); // auto time5 = paddle_mobile::time();
vec_result = paddle_mobile.Predict(input, dims); // vec_result = paddle_mobile.Predict(input, dims);
auto time6 = paddle_mobile::time(); // auto time6 = paddle_mobile::time();
std::cout << "300 predict cost :第" << i << ": " // std::cout << "300 predict cost :第" << i << ": "
<< paddle_mobile::time_diff(time5, time6) << "ms" << std::endl; // << paddle_mobile::time_diff(time5, time6) << "ms" << std::endl;
} // }
auto time4 = paddle_mobile::time(); // auto time4 = paddle_mobile::time();
//
std::cout << "300 predict cost :" // std::cout << "300 predict cost :"
<< paddle_mobile::time_diff(time3, time4) / max << "ms" // << paddle_mobile::time_diff(time3, time4) / max << "ms"
<< std::endl; // << std::endl;
auto biggest = // auto biggest =
std::max_element(std::begin(vec_result), std::end(vec_result)); // std::max_element(std::begin(vec_result), std::end(vec_result));
std::cout << "300 Max element is " << *biggest << " at position " // std::cout << "300 Max element is " << *biggest << " at position "
<< std::distance(std::begin(vec_result), biggest) << std::endl; // << std::distance(std::begin(vec_result), biggest) << std::endl;
//
// 500*500 // // 500*500
std::vector<float> vec_result2; // std::vector<float> vec_result2;
//
std::vector<float> input2; // std::vector<float> input2;
std::vector<int64_t> dims2{1, 1, 500, 500}; // std::vector<int64_t> dims2{1, 1, 500, 500};
GetInput<float>(g_test_image_1x3x224x224, &input2, dims2); // GetInput<float>(g_test_image_1x3x224x224, &input2, dims2);
//
time3 = paddle_mobile::time(); // time3 = paddle_mobile::time();
for (int i = 0; i < max; ++i) { // for (int i = 0; i < max; ++i) {
auto time5 = paddle_mobile::time(); // auto time5 = paddle_mobile::time();
vec_result2 = paddle_mobile.Predict(input2, dims2); // vec_result2 = paddle_mobile.Predict(input2, dims2);
auto time6 = paddle_mobile::time(); // auto time6 = paddle_mobile::time();
std::cout << "500 predict cost :第" << i << ": " // std::cout << "500 predict cost :第" << i << ": "
<< paddle_mobile::time_diff(time5, time6) << "ms" << std::endl; // << paddle_mobile::time_diff(time5, time6) << "ms" << std::endl;
} // }
//
time4 = paddle_mobile::time(); // time4 = paddle_mobile::time();
std::cout << "500 predict cost :" // std::cout << "500 predict cost :"
<< paddle_mobile::time_diff(time3, time4) / max << "ms" // << paddle_mobile::time_diff(time3, time4) / max << "ms"
<< std::endl; // << std::endl;
biggest = std::max_element(std::begin(vec_result2), std::end(vec_result2)); // biggest = std::max_element(std::begin(vec_result2), std::end(vec_result2));
std::cout << "500 Max element is " << *biggest << " at position " // std::cout << "500 Max element is " << *biggest << " at position "
<< std::distance(std::begin(vec_result2), biggest) << std::endl; // << std::distance(std::begin(vec_result2), biggest) << std::endl;
//
// 1000*1000 // // 1000*1000
//
std::vector<float> vec_result3; // std::vector<float> vec_result3;
std::vector<float> input3; // std::vector<float> input3;
std::vector<int64_t> dims3{1, 1, 1000, 1000}; // std::vector<int64_t> dims3{1, 1, 1000, 1000};
GetInput<float>(g_test_image_1x3x224x224, &input3, dims3); // GetInput<float>(g_test_image_1x3x224x224, &input3, dims3);
//
time3 = paddle_mobile::time(); // time3 = paddle_mobile::time();
//
for (int i = 0; i < max; ++i) { // for (int i = 0; i < max; ++i) {
auto time5 = paddle_mobile::time(); // auto time5 = paddle_mobile::time();
vec_result3 = paddle_mobile.Predict(input3, dims3); // vec_result3 = paddle_mobile.Predict(input3, dims3);
auto time6 = paddle_mobile::time(); // auto time6 = paddle_mobile::time();
std::cout << "1000*1000 predict cost :第" << i << ": " // std::cout << "1000*1000 predict cost :第" << i << ": "
<< paddle_mobile::time_diff(time5, time6) << "ms" << std::endl; // << paddle_mobile::time_diff(time5, time6) << "ms" << std::endl;
} // }
time4 = paddle_mobile::time(); // time4 = paddle_mobile::time();
std::cout << "1000*1000 predict cost :" // std::cout << "1000*1000 predict cost :"
<< paddle_mobile::time_diff(time3, time4) / max << "ms" // << paddle_mobile::time_diff(time3, time4) / max << "ms"
<< std::endl; // << std::endl;
biggest = std::max_element(std::begin(vec_result3), std::end(vec_result3)); // biggest = std::max_element(std::begin(vec_result3), std::end(vec_result3));
std::cout << "1000*1000 Max element is " << *biggest << " at position " // std::cout << "1000*1000 Max element is " << *biggest << " at position "
<< std::distance(std::begin(vec_result3), biggest) << std::endl; // << std::distance(std::begin(vec_result3), biggest) << std::endl;
// 224*224 // 224*224
std::vector<float> vec_result4; std::vector<float> vec_result4;
std::vector<float> input4; std::vector<float> input4;
std::vector<int64_t> dims4{1, 1, 224, 224}; std::vector<int64_t> dims4{1, 1, 300, 300};
GetInput<float>(g_test_image_1x3x224x224, &input4, dims4); GetInput<float>(g_test_image_1x3x224x224, &input4, dims4);
time3 = paddle_mobile::time(); time3 = paddle_mobile::time();
...@@ -127,13 +127,13 @@ int main() { ...@@ -127,13 +127,13 @@ int main() {
<< paddle_mobile::time_diff(time5, time6) << "ms" << std::endl; << paddle_mobile::time_diff(time5, time6) << "ms" << std::endl;
} }
time4 = paddle_mobile::time(); auto time4 = paddle_mobile::time();
std::cout << "224*224 predict cost :" std::cout << "224*224 predict cost :"
<< paddle_mobile::time_diff(time3, time4) / max << "ms" << paddle_mobile::time_diff(time3, time4) / max << "ms"
<< std::endl; << std::endl;
biggest = std::max_element(std::begin(vec_result4), std::end(vec_result4)); // biggest = std::max_element(std::begin(vec_result4), std::end(vec_result4));
std::cout << "224*224 Max element is " << *biggest << " at position " // std::cout << "224*224 Max element is " << *biggest << " at position "
<< std::distance(std::begin(vec_result4), biggest) << std::endl; // << std::distance(std::begin(vec_result4), biggest) << std::endl;
} }
return 0; return 0;
......
...@@ -62,7 +62,7 @@ static const char *g_imgfssd_ar = "../images/test_image_ssd_ar"; ...@@ -62,7 +62,7 @@ static const char *g_imgfssd_ar = "../images/test_image_ssd_ar";
static const char *g_imgfssd_ar1 = "../images/003_0001.txt"; static const char *g_imgfssd_ar1 = "../images/003_0001.txt";
static const char *g_img = "../images/img.bin"; static const char *g_img = "../images/img.bin";
static const char *g_yolo_img = "../images/in_put_1_3_416_416_2"; static const char *g_yolo_img = "../images/in_put_1_3_416_416_2";
static const char *g_super_img = "../images/test_image_super"; static const char *g_super_img = "../images/mingren_input_data";
static const char *g_mobilenet_img = "../images/image"; static const char *g_mobilenet_img = "../images/image";
using paddle_mobile::framework::DDim; using paddle_mobile::framework::DDim;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册