未验证 提交 4359398d 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #518 from codeWorm2015/metal

add kernel imp
...@@ -77,6 +77,7 @@ metal/Pods/ ...@@ -77,6 +77,7 @@ metal/Pods/
SwiftProtobuf.framework SwiftProtobuf.framework
paddle-mobile.xcworkspace paddle-mobile.xcworkspace
metal/models/ metal/models/
metal/images/
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
<key>paddle-mobile-demo.xcscheme</key> <key>paddle-mobile-demo.xcscheme</key>
<dict> <dict>
<key>orderHint</key> <key>orderHint</key>
<integer>3</integer> <integer>4</integer>
</dict> </dict>
</dict> </dict>
</dict> </dict>
......
...@@ -13,19 +13,36 @@ ...@@ -13,19 +13,36 @@
limitations under the License. */ limitations under the License. */
import UIKit import UIKit
import MetalKit
import paddle_mobile import paddle_mobile
class ViewController: UIViewController { class ViewController: UIViewController {
let device: MTLDevice! = MTLCreateSystemDefaultDevice()
var textureLoader: MTKTextureLoader!
// let queue: MTLCommandQueue
override func viewDidLoad() { override func viewDidLoad() {
super.viewDidLoad() super.viewDidLoad()
let queue = device.makeCommandQueue()
textureLoader = MTKTextureLoader.init(device: device)
guard let appleImage = UIImage.init(named: "apple.jpg"), let cgImage = appleImage.cgImage else {
fatalError(" image nil !")
}
let texture = try? textureLoader.newTexture(cgImage: cgImage, options: [:]) ?! " texture loader error"
guard let inTexture = texture else {
fatalError(" texture is nil !")
}
let loader = Loader<Float>.init() let loader = Loader<Float>.init()
do { do {
let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null" let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null"
let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null" let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null"
let program = try loader.load(modelPath: modelPath, paraPath: paraPath) let program = try loader.load(device: device, modelPath: modelPath, paraPath: paraPath)
let executor = try Executor<Float>.init(inProgram: program) let executor = try Executor<Float>.init(inProgram: program)
let output = try executor.predict(input: Texture.init()) let output = try executor.predict(input: inTexture, expect: [1, 224, 224, 3])
print(output) print(output)
} catch let error { } catch let error {
print(error) print(error)
......
...@@ -30,11 +30,15 @@ ...@@ -30,11 +30,15 @@
FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB520E11CC20081E9F8 /* OpDesc.swift */; }; FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB520E11CC20081E9F8 /* OpDesc.swift */; };
FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB620E11CC20081E9F8 /* Attribute.swift */; }; FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB620E11CC20081E9F8 /* Attribute.swift */; };
FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB720E11CC20081E9F8 /* BlockDesc.swift */; }; FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB720E11CC20081E9F8 /* BlockDesc.swift */; };
FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC1B16B220EC9A4F00678B91 /* Kernels.metal */; };
FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC1B186520ECF1C600678B91 /* ResizeKernel.swift */; };
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC60DB8820E9AAA500FF203F /* MetalExtension.swift */; };
FC82735920E3C04200BE430A /* OpCreator.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC82735820E3C04200BE430A /* OpCreator.swift */; }; FC82735920E3C04200BE430A /* OpCreator.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC82735820E3C04200BE430A /* OpCreator.swift */; };
FC9D037920E229E4000F735A /* OpParam.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037820E229E4000F735A /* OpParam.swift */; }; FC9D037920E229E4000F735A /* OpParam.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037820E229E4000F735A /* OpParam.swift */; };
FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037F20E22FBB000F735A /* FeedOp.swift */; }; FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037F20E22FBB000F735A /* FeedOp.swift */; };
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038120E2312E000F735A /* FetchOp.swift */; }; FC9D038220E2312E000F735A /* FetchOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038120E2312E000F735A /* FetchOp.swift */; };
FC9D038420E23B01000F735A /* Texture.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038320E23B01000F735A /* Texture.swift */; }; FC9D038420E23B01000F735A /* Texture.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038320E23B01000F735A /* Texture.swift */; };
FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCF2D73720E64E70007AC5F5 /* Kernel.swift */; };
/* End PBXBuildFile section */ /* End PBXBuildFile section */
/* Begin PBXFileReference section */ /* Begin PBXFileReference section */
...@@ -65,11 +69,15 @@ ...@@ -65,11 +69,15 @@
FC039BB520E11CC20081E9F8 /* OpDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OpDesc.swift; sourceTree = "<group>"; }; FC039BB520E11CC20081E9F8 /* OpDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OpDesc.swift; sourceTree = "<group>"; };
FC039BB620E11CC20081E9F8 /* Attribute.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Attribute.swift; sourceTree = "<group>"; }; FC039BB620E11CC20081E9F8 /* Attribute.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Attribute.swift; sourceTree = "<group>"; };
FC039BB720E11CC20081E9F8 /* BlockDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BlockDesc.swift; sourceTree = "<group>"; }; FC039BB720E11CC20081E9F8 /* BlockDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BlockDesc.swift; sourceTree = "<group>"; };
FC1B16B220EC9A4F00678B91 /* Kernels.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = Kernels.metal; sourceTree = "<group>"; };
FC1B186520ECF1C600678B91 /* ResizeKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ResizeKernel.swift; sourceTree = "<group>"; };
FC60DB8820E9AAA500FF203F /* MetalExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalExtension.swift; sourceTree = "<group>"; };
FC82735820E3C04200BE430A /* OpCreator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpCreator.swift; sourceTree = "<group>"; }; FC82735820E3C04200BE430A /* OpCreator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpCreator.swift; sourceTree = "<group>"; };
FC9D037820E229E4000F735A /* OpParam.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpParam.swift; sourceTree = "<group>"; }; FC9D037820E229E4000F735A /* OpParam.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpParam.swift; sourceTree = "<group>"; };
FC9D037F20E22FBB000F735A /* FeedOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FeedOp.swift; sourceTree = "<group>"; }; FC9D037F20E22FBB000F735A /* FeedOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FeedOp.swift; sourceTree = "<group>"; };
FC9D038120E2312E000F735A /* FetchOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FetchOp.swift; sourceTree = "<group>"; }; FC9D038120E2312E000F735A /* FetchOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FetchOp.swift; sourceTree = "<group>"; };
FC9D038320E23B01000F735A /* Texture.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture.swift; sourceTree = "<group>"; }; FC9D038320E23B01000F735A /* Texture.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture.swift; sourceTree = "<group>"; };
FCF2D73720E64E70007AC5F5 /* Kernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; name = Kernel.swift; path = "paddle-mobile/Operators/Kernels/Kernel.swift"; sourceTree = SOURCE_ROOT; };
/* End PBXFileReference section */ /* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */ /* Begin PBXFrameworksBuildPhase section */
...@@ -140,6 +148,7 @@ ...@@ -140,6 +148,7 @@
FC039B9420E11C9A0081E9F8 /* Extensions.swift */, FC039B9420E11C9A0081E9F8 /* Extensions.swift */,
FC039B9520E11C9A0081E9F8 /* Errors.swift */, FC039B9520E11C9A0081E9F8 /* Errors.swift */,
FC039B9620E11C9A0081E9F8 /* Types.swift */, FC039B9620E11C9A0081E9F8 /* Types.swift */,
FC60DB8820E9AAA500FF203F /* MetalExtension.swift */,
); );
path = Common; path = Common;
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -157,8 +166,8 @@ ...@@ -157,8 +166,8 @@
FC039BA320E11CBC0081E9F8 /* Operators */ = { FC039BA320E11CBC0081E9F8 /* Operators */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC086BA520E67E8500D85EF7 /* Kernels */,
FCD592FA20E248EC00252966 /* Base */, FCD592FA20E248EC00252966 /* Base */,
FCD592F920E248EC00252966 /* Kernels */,
FC039BA420E11CBC0081E9F8 /* ConvOp.swift */, FC039BA420E11CBC0081E9F8 /* ConvOp.swift */,
FC039BA520E11CBC0081E9F8 /* ElementwiseAddOp.swift */, FC039BA520E11CBC0081E9F8 /* ElementwiseAddOp.swift */,
FC039BA720E11CBC0081E9F8 /* BatchNormOp.swift */, FC039BA720E11CBC0081E9F8 /* BatchNormOp.swift */,
...@@ -185,9 +194,12 @@ ...@@ -185,9 +194,12 @@
path = Program; path = Program;
sourceTree = "<group>"; sourceTree = "<group>";
}; };
FCD592F920E248EC00252966 /* Kernels */ = { FC086BA520E67E8500D85EF7 /* Kernels */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FCF2D73720E64E70007AC5F5 /* Kernel.swift */,
FC1B16B220EC9A4F00678B91 /* Kernels.metal */,
FC1B186520ECF1C600678B91 /* ResizeKernel.swift */,
); );
path = Kernels; path = Kernels;
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -308,6 +320,10 @@ ...@@ -308,6 +320,10 @@
FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */, FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */,
FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */, FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */,
FC9D037920E229E4000F735A /* OpParam.swift in Sources */, FC9D037920E229E4000F735A /* OpParam.swift in Sources */,
FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */,
FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */,
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */,
FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */,
FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */, FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */,
FC039BA020E11CB20081E9F8 /* Dim.swift in Sources */, FC039BA020E11CB20081E9F8 /* Dim.swift in Sources */,
FC039BB820E11CC20081E9F8 /* framework.pb.swift in Sources */, FC039BB820E11CC20081E9F8 /* framework.pb.swift in Sources */,
...@@ -461,17 +477,19 @@ ...@@ -461,17 +477,19 @@
CODE_SIGN_IDENTITY = ""; CODE_SIGN_IDENTITY = "";
CODE_SIGN_STYLE = Automatic; CODE_SIGN_STYLE = Automatic;
DEFINES_MODULE = YES; DEFINES_MODULE = YES;
DEVELOPMENT_TEAM = Z5M2UUN5YV; DEVELOPMENT_TEAM = A798K58VVL;
DYLIB_COMPATIBILITY_VERSION = 1; DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1; DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath"; DYLIB_INSTALL_NAME_BASE = "@rpath";
INFOPLIST_FILE = "paddle-mobile/Info.plist"; INFOPLIST_FILE = "paddle-mobile/Info.plist";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks"; INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = ( LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)", "$(inherited)",
"@executable_path/Frameworks", "@executable_path/Frameworks",
"@loader_path/Frameworks", "@loader_path/Frameworks",
); );
MTL_LANGUAGE_REVISION = UseDeploymentTarget;
PRODUCT_BUNDLE_IDENTIFIER = "orange.paddle-mobile"; PRODUCT_BUNDLE_IDENTIFIER = "orange.paddle-mobile";
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES; SKIP_INSTALL = YES;
...@@ -487,17 +505,19 @@ ...@@ -487,17 +505,19 @@
CODE_SIGN_IDENTITY = ""; CODE_SIGN_IDENTITY = "";
CODE_SIGN_STYLE = Automatic; CODE_SIGN_STYLE = Automatic;
DEFINES_MODULE = YES; DEFINES_MODULE = YES;
DEVELOPMENT_TEAM = Z5M2UUN5YV; DEVELOPMENT_TEAM = A798K58VVL;
DYLIB_COMPATIBILITY_VERSION = 1; DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1; DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath"; DYLIB_INSTALL_NAME_BASE = "@rpath";
INFOPLIST_FILE = "paddle-mobile/Info.plist"; INFOPLIST_FILE = "paddle-mobile/Info.plist";
INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks"; INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = ( LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)", "$(inherited)",
"@executable_path/Frameworks", "@executable_path/Frameworks",
"@loader_path/Frameworks", "@loader_path/Frameworks",
); );
MTL_LANGUAGE_REVISION = UseDeploymentTarget;
PRODUCT_BUNDLE_IDENTIFIER = "orange.paddle-mobile"; PRODUCT_BUNDLE_IDENTIFIER = "orange.paddle-mobile";
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES; SKIP_INSTALL = YES;
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
<key>paddle-mobile.xcscheme</key> <key>paddle-mobile.xcscheme</key>
<dict> <dict>
<key>orderHint</key> <key>orderHint</key>
<integer>4</integer> <integer>3</integer>
</dict> </dict>
</dict> </dict>
</dict> </dict>
......
...@@ -20,4 +20,5 @@ public enum PaddleMobileError: Error{ ...@@ -20,4 +20,5 @@ public enum PaddleMobileError: Error{
case memoryError(message: String) case memoryError(message: String)
case paramError(message: String) case paramError(message: String)
case opError(message: String) case opError(message: String)
case predictError(message: String)
} }
//
// MetalExtension.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/2.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
fileprivate var defaultMetalLibrary: MTLLibrary?
fileprivate var paddleMobileMetalLibrary: MTLLibrary?
extension MTLDevice {
func defaultLibrary() -> MTLLibrary {
if defaultMetalLibrary == nil {
defaultMetalLibrary = makeDefaultLibrary()
}
if let inDefaultLib = defaultMetalLibrary {
return inDefaultLib
} else {
fatalError(" default metal libary is nil")
}
}
func paddleMobileLibrary() -> MTLLibrary {
if paddleMobileMetalLibrary == nil {
guard let path = Bundle.init(for: Kernel.self).path(forResource: "default", ofType: "metallib") else {
fatalError("Counld't find paddle mobile library")
}
do {
paddleMobileMetalLibrary = try makeLibrary(filepath: path)
} catch _ {
fatalError("Counld't load paddle mobile library")
}
paddleMobileMetalLibrary = makeDefaultLibrary()
}
if let inPaddleMobileLib = paddleMobileMetalLibrary {
return inPaddleMobileLib
} else {
fatalError("PaddleMobile metal libary is nil")
}
}
func pipeLine(funcName: String, inPaddleMobileLib: Bool = true) -> MTLComputePipelineState {
let useLib = inPaddleMobileLib ? paddleMobileLibrary() : defaultLibrary()
guard let function = useLib.makeFunction(name: funcName) else {
fatalError(" function " + funcName + " not found")
}
do {
let pipLine = try makeComputePipelineState(function: function)
return pipLine
} catch _ {
fatalError("make pip line error occured")
}
}
}
extension MTLComputeCommandEncoder {
func dispatch(computePipline: MTLComputePipelineState, outTexture: MTLTexture) {
let slices = (outTexture.depth + 3)/4
let width = computePipline.threadExecutionWidth
let height = computePipline.maxTotalThreadsPerThreadgroup/width
let threadsPerGroup = MTLSize.init(width: width, height: height, depth: 1)
let groupWidth = (outTexture.width + width - 1)/width
let groupHeight = (outTexture.height + height - 1)/height
let groupDepth = slices
let groups = MTLSize.init(width: groupWidth, height: groupHeight, depth: groupDepth)
setComputePipelineState(computePipline)
dispatchThreadgroups(groups, threadsPerThreadgroup: threadsPerGroup)
}
}
...@@ -40,3 +40,10 @@ extension Texture: Variant { ...@@ -40,3 +40,10 @@ extension Texture: Variant {
extension ResultHolder: Variant { extension ResultHolder: Variant {
} }
extension InputTexture: Variant {
}
extension MTLTexture where Self: Variant {
}
...@@ -64,17 +64,23 @@ public class Executor<P: PrecisionType> { ...@@ -64,17 +64,23 @@ public class Executor<P: PrecisionType> {
} }
} }
public func predict(input: Texture) throws -> ResultHolder<P> { public func predict(input: MTLTexture, expect: [Int]) throws -> ResultHolder<P> {
program.scope[program.feedKey] = input let inputTexture = InputTexture.init(inMTLTexture: input, inExpectDim: Dim.init(inDim: expect))
program.scope.setInput(input: inputTexture)
for op in ops { for op in ops {
op.run() op.run()
} }
let outputVar = program.scope[program.fetchKey]
guard let outputVar = program.scope.output() else {
throw PaddleMobileError.netError(message: "output nil")
}
guard let output = outputVar as? ResultHolder<P> else { guard let output = outputVar as? ResultHolder<P> else {
throw PaddleMobileError.netError(message: "output var type error") throw PaddleMobileError.netError(message: "output var type error")
} }
return output return output
} }
} }
//public let paddle_executor: Executor = Executor.init() //public let paddle_executor: Executor = Executor.init()
...@@ -81,7 +81,7 @@ public class Loader<P: PrecisionType> { ...@@ -81,7 +81,7 @@ public class Loader<P: PrecisionType> {
} }
} }
public init(){} public init(){}
public func load(modelPath: String, paraPath: String) throws -> Program{ public func load(device: MTLDevice, modelPath: String, paraPath: String) throws -> Program{
guard let modelData = try? Data.init(contentsOf: URL.init(fileURLWithPath: modelPath)) else { guard let modelData = try? Data.init(contentsOf: URL.init(fileURLWithPath: modelPath)) else {
throw PaddleMobileError.loaderError(message: "load " + modelPath + " failed !") throw PaddleMobileError.loaderError(message: "load " + modelPath + " failed !")
} }
...@@ -89,7 +89,6 @@ public class Loader<P: PrecisionType> { ...@@ -89,7 +89,6 @@ public class Loader<P: PrecisionType> {
do { do {
let protoProgram = try PaddleMobile_Framework_Proto_ProgramDesc.init( let protoProgram = try PaddleMobile_Framework_Proto_ProgramDesc.init(
serializedData: modelData) serializedData: modelData)
let scope = Scope.init()
let programDesc = ProgramDesc.init(protoProgram: protoProgram) let programDesc = ProgramDesc.init(protoProgram: protoProgram)
guard let paraLoader = try? ParaLoader.init(paramPath: paraPath) else { guard let paraLoader = try? ParaLoader.init(paramPath: paraPath) else {
...@@ -116,6 +115,8 @@ public class Loader<P: PrecisionType> { ...@@ -116,6 +115,8 @@ public class Loader<P: PrecisionType> {
throw PaddleMobileError.loaderError(message: "feed key or fetch key not found") throw PaddleMobileError.loaderError(message: "feed key or fetch key not found")
} }
let scope = Scope.init(inFeedKey: feedKey, inFetchKey: fetchKey)
// to load memory // to load memory
for block in programDesc.blocks { for block in programDesc.blocks {
for varDesc in block.vars { for varDesc in block.vars {
...@@ -148,19 +149,18 @@ public class Loader<P: PrecisionType> { ...@@ -148,19 +149,18 @@ public class Loader<P: PrecisionType> {
scope[varDesc.name] = tensor scope[varDesc.name] = tensor
} else { } else {
let dim = Dim.init(inDim: tensorDesc.NHWCDim) let dim = Dim.init(inDim: tensorDesc.NHWCDim)
scope[varDesc.name] = Texture.init(inDim: dim, inLayout: .NHWC) scope[varDesc.name] = Texture.init(device: device, inDim: dim)
} }
} else { } else {
if varDesc.name == fetchKey { if varDesc.name == fetchKey {
scope[varDesc.name] = ResultHolder<P>.init(inDim: [], inResult: []) scope[varDesc.name] = ResultHolder<P>.init(inDim: [], inResult: [])
} else if varDesc.name == feedKey { } else if varDesc.name == feedKey {
scope[varDesc.name] = Texture.init()
} }
} }
} }
} }
let program = Program.init(protoProgramDesc: protoProgram, inParamPath: paraPath, inScope: scope, inFeedKey: feedKey, inFetchKey: fetchKey) let program = Program.init(protoProgramDesc: protoProgram, inParamPath: paraPath, inScope: scope)
return program return program
} catch _ { } catch _ {
......
...@@ -23,11 +23,11 @@ import Foundation ...@@ -23,11 +23,11 @@ import Foundation
protocol OpParam { protocol OpParam {
associatedtype OutputType: Variant associatedtype OutputType: Variant
var output: OutputType { get } var output: OutputType { get set }
func outputDesc() -> String func outputDesc() -> String
associatedtype ParamPrecisionType: PrecisionType associatedtype ParamPrecisionType: PrecisionType
init(opDesc: OpDesc, scope: Scope) throws init(opDesc: OpDesc, inScope: Scope) throws
static func getFirstTensor<VarType: Variant>(key: String, map: [String : [String]], from: Scope) throws -> VarType static func getFirstTensor<VarType: Variant>(key: String, map: [String : [String]], from: Scope) throws -> VarType
static func inputX<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType static func inputX<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
static func inputBiase<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType static func inputBiase<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
......
...@@ -82,7 +82,7 @@ class Operator <ParameterType: OpParam>: OperatorProtocol{ ...@@ -82,7 +82,7 @@ class Operator <ParameterType: OpParam>: OperatorProtocol{
attrs = opDesc.attrs attrs = opDesc.attrs
paraInputs = opDesc.paraInputs paraInputs = opDesc.paraInputs
do { do {
para = try ParamType.init(opDesc:opDesc, scope: inScope) para = try ParamType.init(opDesc:opDesc, inScope: inScope)
} catch let error { } catch let error {
throw error throw error
} }
......
...@@ -16,14 +16,14 @@ import Foundation ...@@ -16,14 +16,14 @@ import Foundation
struct BatchNormParam<P: PrecisionType>: OpParam { struct BatchNormParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws { init(opDesc: OpDesc, inScope: Scope) throws {
do { do {
input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: scope) input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: scope) output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: inScope)
inputBias = try BatchNormParam.inputBiase(inputs: opDesc.paraInputs, from: scope) inputBias = try BatchNormParam.inputBiase(inputs: opDesc.paraInputs, from: inScope)
inputMean = try BatchNormParam.inputMean(inputs: opDesc.paraInputs, from: scope) inputMean = try BatchNormParam.inputMean(inputs: opDesc.paraInputs, from: inScope)
inputScale = try BatchNormParam.inputScale(inputs: opDesc.paraInputs, from: scope) inputScale = try BatchNormParam.inputScale(inputs: opDesc.paraInputs, from: inScope)
inputVariance = try BatchNormParam.inputVariance(inputs: opDesc.paraInputs, from: scope) inputVariance = try BatchNormParam.inputVariance(inputs: opDesc.paraInputs, from: inScope)
epsilon = try BatchNormParam.getAttr(key: "epsilon", attrs: opDesc.attrs) epsilon = try BatchNormParam.getAttr(key: "epsilon", attrs: opDesc.attrs)
momentum = try BatchNormParam.getAttr(key: "momentum", attrs: opDesc.attrs) momentum = try BatchNormParam.getAttr(key: "momentum", attrs: opDesc.attrs)
is_test = try BatchNormParam.getAttr(key: "is_test", attrs: opDesc.attrs) is_test = try BatchNormParam.getAttr(key: "is_test", attrs: opDesc.attrs)
...@@ -32,7 +32,7 @@ struct BatchNormParam<P: PrecisionType>: OpParam { ...@@ -32,7 +32,7 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
} }
} }
let input: Texture let input: Texture
let output: Texture var output: Texture
let inputBias: Tensor<ParamPrecisionType> let inputBias: Tensor<ParamPrecisionType>
let inputMean: Tensor<ParamPrecisionType> let inputMean: Tensor<ParamPrecisionType>
let inputScale: Tensor<ParamPrecisionType> let inputScale: Tensor<ParamPrecisionType>
......
...@@ -16,11 +16,11 @@ import Foundation ...@@ -16,11 +16,11 @@ import Foundation
struct ConvParam<P: PrecisionType>: OpParam { struct ConvParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws { init(opDesc: OpDesc, inScope: Scope) throws {
do { do {
filter = try ConvParam.inputFilter(paraInputs: opDesc.paraInputs, from: scope) filter = try ConvParam.inputFilter(paraInputs: opDesc.paraInputs, from: inScope)
input = try ConvParam.input(inputs: opDesc.inputs, from: scope) input = try ConvParam.input(inputs: opDesc.inputs, from: inScope)
output = try ConvParam.output(outputs: opDesc.outputs, from: scope) output = try ConvParam.output(outputs: opDesc.outputs, from: inScope)
stride = try ConvParam.getAttr(key: "strides", attrs: opDesc.attrs) stride = try ConvParam.getAttr(key: "strides", attrs: opDesc.attrs)
paddings = try ConvParam.getAttr(key: "paddings", attrs: opDesc.attrs) paddings = try ConvParam.getAttr(key: "paddings", attrs: opDesc.attrs)
dilations = try ConvParam.getAttr(key: "dilations", attrs: opDesc.attrs) dilations = try ConvParam.getAttr(key: "dilations", attrs: opDesc.attrs)
...@@ -31,7 +31,7 @@ struct ConvParam<P: PrecisionType>: OpParam { ...@@ -31,7 +31,7 @@ struct ConvParam<P: PrecisionType>: OpParam {
} }
let input: Texture let input: Texture
let output: Texture var output: Texture
let filter: Tensor<ParamPrecisionType> let filter: Tensor<ParamPrecisionType>
let stride: [Int32] let stride: [Int32]
let paddings: [Int32] let paddings: [Int32]
......
...@@ -16,11 +16,11 @@ import Foundation ...@@ -16,11 +16,11 @@ import Foundation
struct ElementwiseAddParam<P: PrecisionType>: OpParam { struct ElementwiseAddParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws { init(opDesc: OpDesc, inScope: Scope) throws {
do { do {
input = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: scope) input = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: inScope)
inputY = try ElementwiseAddParam.inputY(inputs: opDesc.inputs, from: scope) inputY = try ElementwiseAddParam.inputY(inputs: opDesc.inputs, from: inScope)
output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: scope) output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: inScope)
axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs) axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs)
} catch let error { } catch let error {
throw error throw error
...@@ -28,7 +28,7 @@ struct ElementwiseAddParam<P: PrecisionType>: OpParam { ...@@ -28,7 +28,7 @@ struct ElementwiseAddParam<P: PrecisionType>: OpParam {
} }
let input: Texture let input: Texture
let inputY: Tensor<P> let inputY: Tensor<P>
let output: Texture var output: Texture
let axis: Int let axis: Int
} }
......
...@@ -16,12 +16,15 @@ import Foundation ...@@ -16,12 +16,15 @@ import Foundation
struct FeedParam<P: PrecisionType>: OpParam{ struct FeedParam<P: PrecisionType>: OpParam{
var output: Texture var output: Texture
var input: Texture var input: InputTexture {
return scope.input() as! InputTexture
}
let scope: Scope
init(opDesc: OpDesc, scope: Scope) throws { init(opDesc: OpDesc, inScope: Scope) throws {
scope = inScope
do { do {
input = try FeedParam.inputX(inputs: opDesc.inputs, from: scope) output = try FeedParam.outputOut(outputs: opDesc.outputs, from: inScope)
output = try FeedParam.outputOut(outputs: opDesc.outputs, from: scope)
} catch let error { } catch let error {
throw error throw error
} }
...@@ -34,10 +37,15 @@ class FeedOp<P: PrecisionType>: Operator<FeedParam<P>>, Runable, Creator, InferS ...@@ -34,10 +37,15 @@ class FeedOp<P: PrecisionType>: Operator<FeedParam<P>>, Runable, Creator, InferS
typealias OpType = FeedOp<P> typealias OpType = FeedOp<P>
func inferShape() { func inferShape() {
// print("feed input: \(para.input.expectDim)")
print("feed output: \(para.output.dim)")
// para.ou/tput.dim = para.input.expectDim
} }
func runImpl() { func runImpl() {
print("feed op") print("feed op")
// let resizeKernel = ResizeKernel.init(device: <#T##MTLDevice#>)
} }
} }
...@@ -15,13 +15,13 @@ ...@@ -15,13 +15,13 @@
import Foundation import Foundation
struct FetchParam<P: PrecisionType>: OpParam{ struct FetchParam<P: PrecisionType>: OpParam{
let output: ResultHolder<P> var output: ResultHolder<P> = ResultHolder.init(inDim: [], inResult: [])
let input: Texture let input: Texture
let scope: Scope
init(opDesc: OpDesc, scope: Scope) throws { init(opDesc: OpDesc, inScope: Scope) throws {
scope = inScope
do { do {
input = try FetchParam.inputX(inputs: opDesc.inputs, from: scope) input = try FetchParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try FetchParam.outputOut(outputs: opDesc.outputs, from: scope)
} catch let error { } catch let error {
throw error throw error
} }
...@@ -32,6 +32,7 @@ struct FetchParam<P: PrecisionType>: OpParam{ ...@@ -32,6 +32,7 @@ struct FetchParam<P: PrecisionType>: OpParam{
class FetchOp<P: PrecisionType>: Operator<FetchParam<P>>, Runable, Creator, InferShaperable{ class FetchOp<P: PrecisionType>: Operator<FetchParam<P>>, Runable, Creator, InferShaperable{
func inferShape() { func inferShape() {
print(para.input.dim) print(para.input.dim)
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Metal
import Foundation
protocol Computable {
associatedtype ParamType
func compute(commandBuffer: MTLCommandBuffer, param: ParamType) throws
}
class Kernel {
let pipline: MTLComputePipelineState
let functionName: String
init(device: MTLDevice, inFunctionName: String) {
pipline = device.pipeLine(funcName: inFunctionName)
functionName = inFunctionName
}
}
//
// Kernels.metal
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/4.
// Copyright © 2018年 orange. All rights reserved.
//
#include <metal_stdlib>
using namespace metal;
struct OutputDim {
ushort width;
ushort height;
ushort strideX;
ushort strideY;
};
kernel void resize(
texture2d<half, access::read> inTexture [[texture(0)]],
texture2d<half, access::write> outTexture [[texture(1)]],
constant OutputDim &params [[buffer(0)]],
uint2 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height()) {
return;
}
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint2 pos = gid.xy * uint2(params.strideX, params.strideY);
const half4 input = inTexture.read(pos);
outTexture.write(half4(input.x, input.y, input.z, 0.0h), gid);
}
//
// ResizeKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/4.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
struct ResizeParam {
let input: MTLTexture
let output: MTLTexture
let expectDim: Dim
}
struct OutputDim {
let width: UInt16
let height: UInt16
let strideX: UInt16
let strideY: UInt16
}
class ResizeKernel: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: ResizeParam) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input, index: 0)
encoder.setTexture(param.output, index: 1)
let strideX = param.input.width/param.expectDim[2]
let strideY = param.input.height/param.expectDim[1]
var outputDim = OutputDim.init(width: UInt16(param.expectDim[1]), height: UInt16(param.expectDim[2]), strideX: UInt16(strideX), strideY: UInt16(strideY))
encoder.setBytes(&outputDim, length: MemoryLayout<OutputDim>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output)
encoder.endEncoding()
}
init(device: MTLDevice) {
super.init(device: device, inFunctionName: "resize")
}
}
...@@ -16,16 +16,16 @@ import Foundation ...@@ -16,16 +16,16 @@ import Foundation
struct ReluParam<P: PrecisionType>: OpParam { struct ReluParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws { init(opDesc: OpDesc, inScope: Scope) throws {
do { do {
input = try ReluParam.inputX(inputs: opDesc.inputs, from: scope) input = try ReluParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try ReluParam.outputOut(outputs: opDesc.outputs, from: scope) output = try ReluParam.outputOut(outputs: opDesc.outputs, from: inScope)
} catch let error { } catch let error {
throw error throw error
} }
} }
let input: Texture let input: Texture
let output: Texture var output: Texture
} }
class ReluOp<P: PrecisionType>: Operator<ReluParam<P>>, Runable, Creator, InferShaperable{ class ReluOp<P: PrecisionType>: Operator<ReluParam<P>>, Runable, Creator, InferShaperable{
......
...@@ -16,15 +16,11 @@ import Foundation ...@@ -16,15 +16,11 @@ import Foundation
public struct Program { public struct Program {
let paramPath: String let paramPath: String
let feedKey: String
let fetchKey: String
let programDesc: ProgramDesc let programDesc: ProgramDesc
let scope: Scope let scope: Scope
init(protoProgramDesc: PaddleMobile_Framework_Proto_ProgramDesc, inParamPath: String, inScope: Scope, inFeedKey: String, inFetchKey: String) { init(protoProgramDesc: PaddleMobile_Framework_Proto_ProgramDesc, inParamPath: String, inScope: Scope) {
programDesc = ProgramDesc.init(protoProgram: protoProgramDesc) programDesc = ProgramDesc.init(protoProgram: protoProgramDesc)
paramPath = inParamPath paramPath = inParamPath
scope = inScope scope = inScope
feedKey = inFeedKey
fetchKey = inFetchKey
} }
} }
...@@ -15,6 +15,29 @@ ...@@ -15,6 +15,29 @@
import Foundation import Foundation
class Scope { class Scope {
let feedKey: String
let fetchKey: String
func setInput(input: Variant) {
vars[feedKey] = input
}
func setOutput(output: Variant) {
vars[fetchKey] = output
}
func input() -> Variant? {
return vars[feedKey];
}
func output() -> Variant? {
return vars[fetchKey];
}
init(inFeedKey: String, inFetchKey: String) {
feedKey = inFeedKey
fetchKey = inFetchKey
}
var vars: [String : Variant] = [:] var vars: [String : Variant] = [:]
subscript(key: String) -> Variant?{ subscript(key: String) -> Variant?{
get { get {
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import Foundation import Foundation
public struct Dim { public struct Dim {
init(inDim: [Int]) { public init(inDim: [Int]) {
dims = inDim dims = inDim
} }
......
...@@ -18,10 +18,8 @@ protocol Tensorial: CustomStringConvertible, CustomDebugStringConvertible{ ...@@ -18,10 +18,8 @@ protocol Tensorial: CustomStringConvertible, CustomDebugStringConvertible{
var dim: Dim { get set } var dim: Dim { get set }
func numel() -> Int func numel() -> Int
var layout: DataLayout { get } var layout: DataLayout { get }
init(inDim: Dim, inLayout: DataLayout)
} }
extension Tensorial { extension Tensorial {
func numel() -> Int { func numel() -> Int {
return dim.numel() return dim.numel()
......
...@@ -15,29 +15,96 @@ ...@@ -15,29 +15,96 @@
import Metal import Metal
import Foundation import Foundation
class InputTexture {
let mtlTexture: MTLTexture
let expectDim: Dim
init(inMTLTexture: MTLTexture, inExpectDim: Dim) {
mtlTexture = inMTLTexture
expectDim = inExpectDim
}
}
extension InputTexture {
var description: String {
get{
return mtlTexture.description
}
}
var debugDescription: String {
get {
return mtlTexture.debugDescription ?? " MetalTexture "
}
}
}
public class Texture: Tensorial { public class Texture: Tensorial {
var dim: Dim var dim: Dim
let textureDesc: MTLTextureDescriptor
var metalTexture: MTLTexture
required public init(inDim: Dim, inLayout: DataLayout = .NHWC) { init(device: MTLDevice, inDim: Dim, inLayout: DataLayout = .NHWC) {
dim = inDim dim = inDim
layout = inLayout layout = inLayout
let tmpTextureDes = MTLTextureDescriptor.init()
if inDim.cout() == 1 {
tmpTextureDes.width = inDim[0]
tmpTextureDes.textureType = .type1D
} else if inDim.cout() == 4 {
tmpTextureDes.height = inDim[1]
tmpTextureDes.width = inDim[2]
// print("n : \(inDim[0])")
// print(inDim[3] * inDim[0])
tmpTextureDes.depth = 1
tmpTextureDes.arrayLength = (inDim[3] * inDim[0] + 3)/4
tmpTextureDes.textureType = .type2DArray
} else {
fatalError(" didn't support yet")
}
tmpTextureDes.pixelFormat = .r32Float
tmpTextureDes.storageMode = .shared
textureDesc = tmpTextureDes
metalTexture = device.makeTexture(descriptor: tmpTextureDes) ?! " texture nil "
} }
private(set) var layout: DataLayout // required public init(inDim: Dim, inLayout: DataLayout = .NHWC, inTexture: MTLTexture) {
// dim = inDim
// let texture: MTLTexture // layout = inLayout
// metalTexture = inTexture
public init(inTexture: MTLTexture, inDim: Dim) { // let tmpTextureDes = MTLTextureDescriptor.init()
// texture = inTexture //
dim = inDim // if inDim.cout() == 1 {
layout = .NHWC // tmpTextureDes.width = inDim[0]
} // tmpTextureDes.textureType = .type1D
// } else if inDim.cout() == 2 {
// tmpTextureDes.height = inDim[0]
// tmpTextureDes.width = inDim[1]
// tmpTextureDes.textureType = .type2D
// } else if inDim.cout() == 3 {
// fatalError(" not support texture dim 3")
// } else if inDim.cout() == 4 {
// tmpTextureDes.height = inDim[1]
// tmpTextureDes.width = inDim[2]
// tmpTextureDes.depth = inDim[3] * inDim[1]
// tmpTextureDes.textureType = .type2DArray
// }
//
// tmpTextureDes.pixelFormat = .r32Float
// tmpTextureDes.storageMode = .shared
// textureDesc = tmpTextureDes
// let device = MTLCreateSystemDefaultDevice()
// metalTexture = device!.makeTexture(descriptor: tmpTextureDes)!
// }
public init(inLayout: DataLayout = .NHWC) { // init() {
dim = Dim.init(inDim: []) // dim = Dim.init(inDim: [])
layout = inLayout // layout = .NCHW
} // let device = MTLCreateSystemDefaultDevice()
// textureDesc = MTLTextureDescriptor.init()
// metalTexture = device!.makeTexture(descriptor: textureDesc)!
// }
private(set) var layout: DataLayout
} }
extension Texture { extension Texture {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册