Commit 64a171ca authored by liuruilong

add relu kernel

Parent b06a4e75
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
<key>paddle-mobile-demo.xcscheme</key> <key>paddle-mobile-demo.xcscheme</key>
<dict> <dict>
<key>orderHint</key> <key>orderHint</key>
<integer>4</integer> <integer>3</integer>
</dict> </dict>
</dict> </dict>
</dict> </dict>
......
...@@ -36,13 +36,13 @@ class ViewController: UIViewController { ...@@ -36,13 +36,13 @@ class ViewController: UIViewController {
fatalError(" texture is nil !") fatalError(" texture is nil !")
} }
let loader = Loader<Float>.init() let loader = Loader<Float16>.init()
do { do {
let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null" let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null"
let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null" let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null"
let program = try loader.load(device: device, modelPath: modelPath, paraPath: paraPath) let program = try loader.load(device: device, modelPath: modelPath, paraPath: paraPath)
let executor = try Executor<Float>.init(inProgram: program) let executor = try Executor<Float16>.init(inDevice: device, inQueue: queue!, inProgram: program)
let output = try executor.predict(input: inTexture, expect: [1, 224, 224, 3]) let output = try executor.predict(input: inTexture, expect: [1, 227, 227, 3])
print(output) print(output)
} catch let error { } catch let error {
print(error) print(error)
......
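The demo's call site now pins the precision to Float16, passes the Metal device and command queue into the executor, and feeds a 227×227 input. Consolidating the right-hand (new) side of the hunk above into one readable sketch — `Loader`, `Executor`, `inTexture` and the `?!` force-unwrap operator are the project's own, while the device/queue setup is an assumption about the surrounding view controller:

import Foundation
import Metal

// Minimal sketch of the updated demo flow, assuming the paddle-mobile framework is imported
// and inTexture already holds the preprocessed input image.
guard let device = MTLCreateSystemDefaultDevice(),
      let queue = device.makeCommandQueue() else {
    fatalError("Metal is not available on this device")
}

let loader = Loader<Float16>.init()
do {
    let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null"
    let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null"
    let program = try loader.load(device: device, modelPath: modelPath, paraPath: paraPath)
    let executor = try Executor<Float16>.init(inDevice: device, inQueue: queue, inProgram: program)
    // Expected input layout is NHWC: batch 1, 227x227 spatial size, 3 channels.
    let output = try executor.predict(input: inTexture, expect: [1, 227, 227, 3])
    print(output)
} catch let error {
    print(error)
}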
...@@ -30,6 +30,10 @@ ...@@ -30,6 +30,10 @@
FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB520E11CC20081E9F8 /* OpDesc.swift */; }; FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB520E11CC20081E9F8 /* OpDesc.swift */; };
FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB620E11CC20081E9F8 /* Attribute.swift */; }; FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB620E11CC20081E9F8 /* Attribute.swift */; };
FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB720E11CC20081E9F8 /* BlockDesc.swift */; }; FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB720E11CC20081E9F8 /* BlockDesc.swift */; };
FC0E2DBA20EE3B8D009C1FAC /* ReluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DB920EE3B8D009C1FAC /* ReluKernel.swift */; };
FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBB20EE45FE009C1FAC /* ConvKernel.swift */; };
FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBD20EE460D009C1FAC /* BatchNormKernel.swift */; };
FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */; };
FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC1B16B220EC9A4F00678B91 /* Kernels.metal */; }; FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC1B16B220EC9A4F00678B91 /* Kernels.metal */; };
FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC1B186520ECF1C600678B91 /* ResizeKernel.swift */; }; FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC1B186520ECF1C600678B91 /* ResizeKernel.swift */; };
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC60DB8820E9AAA500FF203F /* MetalExtension.swift */; }; FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC60DB8820E9AAA500FF203F /* MetalExtension.swift */; };
...@@ -69,6 +73,10 @@ ...@@ -69,6 +73,10 @@
FC039BB520E11CC20081E9F8 /* OpDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OpDesc.swift; sourceTree = "<group>"; }; FC039BB520E11CC20081E9F8 /* OpDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OpDesc.swift; sourceTree = "<group>"; };
FC039BB620E11CC20081E9F8 /* Attribute.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Attribute.swift; sourceTree = "<group>"; }; FC039BB620E11CC20081E9F8 /* Attribute.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Attribute.swift; sourceTree = "<group>"; };
FC039BB720E11CC20081E9F8 /* BlockDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BlockDesc.swift; sourceTree = "<group>"; }; FC039BB720E11CC20081E9F8 /* BlockDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BlockDesc.swift; sourceTree = "<group>"; };
FC0E2DB920EE3B8D009C1FAC /* ReluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ReluKernel.swift; sourceTree = "<group>"; };
FC0E2DBB20EE45FE009C1FAC /* ConvKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvKernel.swift; sourceTree = "<group>"; };
FC0E2DBD20EE460D009C1FAC /* BatchNormKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BatchNormKernel.swift; sourceTree = "<group>"; };
FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ElementwiseAddKernel.swift; sourceTree = "<group>"; };
FC1B16B220EC9A4F00678B91 /* Kernels.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = Kernels.metal; sourceTree = "<group>"; }; FC1B16B220EC9A4F00678B91 /* Kernels.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = Kernels.metal; sourceTree = "<group>"; };
FC1B186520ECF1C600678B91 /* ResizeKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ResizeKernel.swift; sourceTree = "<group>"; }; FC1B186520ECF1C600678B91 /* ResizeKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ResizeKernel.swift; sourceTree = "<group>"; };
FC60DB8820E9AAA500FF203F /* MetalExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalExtension.swift; sourceTree = "<group>"; }; FC60DB8820E9AAA500FF203F /* MetalExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalExtension.swift; sourceTree = "<group>"; };
...@@ -197,9 +205,13 @@ ...@@ -197,9 +205,13 @@
FC086BA520E67E8500D85EF7 /* Kernels */ = { FC086BA520E67E8500D85EF7 /* Kernels */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC0E2DBB20EE45FE009C1FAC /* ConvKernel.swift */,
FCF2D73720E64E70007AC5F5 /* Kernel.swift */, FCF2D73720E64E70007AC5F5 /* Kernel.swift */,
FC1B16B220EC9A4F00678B91 /* Kernels.metal */, FC1B16B220EC9A4F00678B91 /* Kernels.metal */,
FC1B186520ECF1C600678B91 /* ResizeKernel.swift */, FC1B186520ECF1C600678B91 /* ResizeKernel.swift */,
FC0E2DB920EE3B8D009C1FAC /* ReluKernel.swift */,
FC0E2DBD20EE460D009C1FAC /* BatchNormKernel.swift */,
FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */,
); );
path = Kernels; path = Kernels;
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -316,12 +328,14 @@ ...@@ -316,12 +328,14 @@
files = ( files = (
FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */, FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */,
FC039B9F20E11CB20081E9F8 /* Tensor.swift in Sources */, FC039B9F20E11CB20081E9F8 /* Tensor.swift in Sources */,
FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */,
FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */, FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */,
FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */, FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */,
FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */, FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */,
FC9D037920E229E4000F735A /* OpParam.swift in Sources */, FC9D037920E229E4000F735A /* OpParam.swift in Sources */,
FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */, FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */,
FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */, FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */,
FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */,
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */, FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */,
FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */, FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */,
FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */, FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */,
...@@ -335,7 +349,9 @@ ...@@ -335,7 +349,9 @@
FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */, FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */,
FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */, FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */,
FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */, FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */,
FC0E2DBA20EE3B8D009C1FAC /* ReluKernel.swift in Sources */,
FC82735920E3C04200BE430A /* OpCreator.swift in Sources */, FC82735920E3C04200BE430A /* OpCreator.swift in Sources */,
FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */,
FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */, FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */,
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */, FC9D038220E2312E000F735A /* FetchOp.swift in Sources */,
FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */, FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */,
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
<key>paddle-mobile.xcscheme</key> <key>paddle-mobile.xcscheme</key>
<dict> <dict>
<key>orderHint</key> <key>orderHint</key>
<integer>3</integer> <integer>4</integer>
</dict> </dict>
</dict> </dict>
</dict> </dict>
......
...@@ -29,11 +29,11 @@ extension MTLDevice { ...@@ -29,11 +29,11 @@ extension MTLDevice {
fatalError("Counld't find paddle mobile library") fatalError("Counld't find paddle mobile library")
} }
do { do {
print(path)
paddleMobileMetalLibrary = try makeLibrary(filepath: path) paddleMobileMetalLibrary = try makeLibrary(filepath: path)
} catch _ { } catch _ {
fatalError("Counld't load paddle mobile library") fatalError("Counld't load paddle mobile library")
} }
paddleMobileMetalLibrary = makeDefaultLibrary()
} }
if let inPaddleMobileLib = paddleMobileMetalLibrary { if let inPaddleMobileLib = paddleMobileMetalLibrary {
...@@ -67,11 +67,17 @@ extension MTLComputeCommandEncoder { ...@@ -67,11 +67,17 @@ extension MTLComputeCommandEncoder {
let height = computePipline.maxTotalThreadsPerThreadgroup/width let height = computePipline.maxTotalThreadsPerThreadgroup/width
let threadsPerGroup = MTLSize.init(width: width, height: height, depth: 1) let threadsPerGroup = MTLSize.init(width: width, height: height, depth: 1)
print(" threads per group: \(threadsPerGroup) ")
print(" out texture width: \(outTexture.width) , out texture height: \(outTexture.height)")
let groupWidth = (outTexture.width + width - 1)/width let groupWidth = (outTexture.width + width - 1)/width
let groupHeight = (outTexture.height + height - 1)/height let groupHeight = (outTexture.height + height - 1)/height
let groupDepth = slices let groupDepth = slices
let groups = MTLSize.init(width: groupWidth, height: groupHeight, depth: groupDepth) let groups = MTLSize.init(width: groupWidth, height: groupHeight, depth: groupDepth)
print("groups: \(groups) ")
setComputePipelineState(computePipline) setComputePipelineState(computePipline)
dispatchThreadgroups(groups, threadsPerThreadgroup: threadsPerGroup) dispatchThreadgroups(groups, threadsPerThreadgroup: threadsPerGroup)
} }
......
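The dispatch helper added above sizes the compute grid with ceiling division over the output texture and its array slices. A worked example with assumed numbers, since `width` and `slices` are defined outside this hunk:

// Hypothetical values to illustrate dispatch(computePipline:outTexture:):
let width = 32                                  // assumed thread execution width
let height = 512 / width                        // maxTotalThreadsPerThreadgroup / width = 16
let groupWidth  = (227 + width - 1) / width     // ceil(227 / 32) = 8 threadgroups across
let groupHeight = (227 + height - 1) / height   // ceil(227 / 16) = 15 threadgroups down
let groupDepth  = 64                            // one threadgroup per texture slice (assumed 64 slices)

With those numbers the kernel launches as an 8 × 15 × 64 grid of 32 × 16 × 1 threadgroups, and the bounds check at the top of each Metal kernel discards threads that fall outside the 227 × 227 output.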
...@@ -14,14 +14,22 @@ ...@@ -14,14 +14,22 @@
import Foundation import Foundation
//typealias Float16 = Int16 public typealias Float16 = Int16
//extension Float16: PrecisionType { extension Float16: PrecisionType {
//} public init(inFloat: Float32) {
self = Int16(inFloat)
}
}
public protocol PrecisionType { public protocol PrecisionType {
init(inFloat: Float32)
} }
extension Float32: PrecisionType { extension Float32: PrecisionType {
public init(inFloat: Float32) {
self = inFloat
}
} }
public enum DataLayout { public enum DataLayout {
......
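Note that `Float16` is introduced here as a typealias for `Int16`, so it only reserves 16 bits of storage; `Int16(inFloat)` truncates the value to a whole number (and traps outside the Int16 range) rather than producing an IEEE-754 half. If genuine half-precision values are needed on the CPU side, one option is to batch-convert with Accelerate's vImage. This is a hedged sketch of an alternative, not what the commit does:

import Accelerate

// Convert Float32 values to IEEE-754 half precision, stored in UInt16 payloads.
func toHalf(_ input: [Float]) -> [UInt16] {
    var output = [UInt16](repeating: 0, count: input.count)
    input.withUnsafeBufferPointer { src in
        output.withUnsafeMutableBufferPointer { dst in
            var srcBuf = vImage_Buffer(data: UnsafeMutableRawPointer(mutating: src.baseAddress),
                                       height: 1,
                                       width: vImagePixelCount(input.count),
                                       rowBytes: input.count * MemoryLayout<Float>.size)
            var dstBuf = vImage_Buffer(data: dst.baseAddress,
                                       height: 1,
                                       width: vImagePixelCount(input.count),
                                       rowBytes: input.count * MemoryLayout<UInt16>.size)
            _ = vImageConvert_PlanarFtoPlanar16F(&srcBuf, &dstBuf, vImage_Flags(kvImageNoFlags))
        }
    }
    return output
}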
...@@ -48,13 +48,16 @@ extension ResultHolder: CustomDebugStringConvertible, CustomStringConvertible { ...@@ -48,13 +48,16 @@ extension ResultHolder: CustomDebugStringConvertible, CustomStringConvertible {
public class Executor<P: PrecisionType> { public class Executor<P: PrecisionType> {
var ops: [Runable & InferShaperable] = [] var ops: [Runable & InferShaperable] = []
let program: Program let program: Program
let device: MTLDevice
public init(inProgram: Program) throws { let queue: MTLCommandQueue
public init(inDevice:MTLDevice, inQueue: MTLCommandQueue, inProgram: Program) throws {
program = inProgram program = inProgram
device = inDevice
queue = inQueue
for block in inProgram.programDesc.blocks { for block in inProgram.programDesc.blocks {
for op in block.ops { for op in block.ops {
do { do {
let op = try OpCreator<P>.shared.creat(opDesc: op, scope: inProgram.scope) let op = try OpCreator<P>.shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope)
op.inferShape() op.inferShape()
ops.append(op) ops.append(op)
} catch let error { } catch let error {
...@@ -65,12 +68,29 @@ public class Executor<P: PrecisionType> { ...@@ -65,12 +68,29 @@ public class Executor<P: PrecisionType> {
} }
public func predict(input: MTLTexture, expect: [Int]) throws -> ResultHolder<P> { public func predict(input: MTLTexture, expect: [Int]) throws -> ResultHolder<P> {
let beforeDate = Date.init()
let inputTexture = InputTexture.init(inMTLTexture: input, inExpectDim: Dim.init(inDim: expect)) let inputTexture = InputTexture.init(inMTLTexture: input, inExpectDim: Dim.init(inDim: expect))
program.scope.setInput(input: inputTexture) program.scope.setInput(input: inputTexture)
guard let buffer = queue.makeCommandBuffer() else {
throw PaddleMobileError.predictError(message: "CommandBuffer is nil")
}
for op in ops { for op in ops {
op.run() do {
try op.run(device: device, buffer: buffer)
} catch let error {
throw error
}
}
buffer.addCompletedHandler { (commandbuffer) in
let afterDate = Date.init()
print(afterDate.timeIntervalSince(beforeDate))
print(" encoder end ! ")
} }
buffer.commit()
guard let outputVar = program.scope.output() else { guard let outputVar = program.scope.output() else {
throw PaddleMobileError.netError(message: "output nil") throw PaddleMobileError.netError(message: "output nil")
} }
...@@ -78,6 +98,8 @@ public class Executor<P: PrecisionType> { ...@@ -78,6 +98,8 @@ public class Executor<P: PrecisionType> {
guard let output = outputVar as? ResultHolder<P> else { guard let output = outputVar as? ResultHolder<P> else {
throw PaddleMobileError.netError(message: "output var type error") throw PaddleMobileError.netError(message: "output var type error")
} }
return output return output
} }
......
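One thing to keep in mind with the new flow: `buffer.commit()` returns immediately and the completion handler above only prints timing, so `predict` hands back its result holder before the GPU work is guaranteed to have finished. A hedged sketch of a synchronous variant (an assumption about the intended behaviour, not part of this commit):

buffer.addCompletedHandler { _ in
    print("encode + execute: \(Date().timeIntervalSince(beforeDate)) s")
}
buffer.commit()
buffer.waitUntilCompleted()   // block until the command buffer finishes before reading outputs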
...@@ -69,10 +69,23 @@ public class Loader<P: PrecisionType> { ...@@ -69,10 +69,23 @@ public class Loader<P: PrecisionType> {
Precision is not inferred from the tensor's Data Type here; it is fixed by the external generic parameter.
*/ */
let bytesRead = fread(tensor.data.pointer, 1, tensor.data.size, file) // The params are currently stored as Float; this should be decided by the model's data type
guard bytesRead == tensor.data.size else { let tmpCapacity = MemoryLayout<Float>.size * tensor.numel()
throw PaddleMobileError.loaderError(message: "param read size error") let tmpPointer = UnsafeMutablePointer<Float>.allocate(capacity: tmpCapacity);
// let bytesRead = fread(tensor.data.pointer, 1, tensor.data.size, file)
// guard bytesRead == tensor.data.size else {
// throw PaddleMobileError.loaderError(message: "param read size error")
// }
// TODO: use script to convert
let bytesRead = fread(tmpPointer, 1, tmpCapacity, file)
for i in 0..<tensor.numel() {
tensor.data[i] = P.init(inFloat: tmpPointer[i])
} }
tmpPointer.deinitialize(count: tmpCapacity)
tmpPointer.deallocate()
nowIndex += bytesRead nowIndex += bytesRead
} }
...@@ -125,9 +138,9 @@ public class Loader<P: PrecisionType> { ...@@ -125,9 +138,9 @@ public class Loader<P: PrecisionType> {
throw PaddleMobileError.loaderError(message: "get tensor desc failed") throw PaddleMobileError.loaderError(message: "get tensor desc failed")
} }
guard (try? tensorDesc.dataType.dataTypeSize()) == MemoryLayout<P>.size else { // guard (try? tensorDesc.dataType.dataTypeSize()) == MemoryLayout<P>.size else {
throw PaddleMobileError.memoryError(message: "PrecisionType not support") // throw PaddleMobileError.memoryError(message: "PrecisionType not support")
} // }
if (varDesc.persistable if (varDesc.persistable
&& varDesc.type != .FeedMiniBatch && varDesc.type != .FeedMiniBatch
...@@ -149,7 +162,7 @@ public class Loader<P: PrecisionType> { ...@@ -149,7 +162,7 @@ public class Loader<P: PrecisionType> {
scope[varDesc.name] = tensor scope[varDesc.name] = tensor
} else { } else {
let dim = Dim.init(inDim: tensorDesc.NHWCDim) let dim = Dim.init(inDim: tensorDesc.NHWCDim)
scope[varDesc.name] = Texture.init(device: device, inDim: dim) scope[varDesc.name] = Texture<P>.init(device: device, inDim: dim)
} }
} else { } else {
if varDesc.name == fetchKey { if varDesc.name == fetchKey {
......
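The rewritten parameter-loading path reads the raw Float32 weights into a temporary buffer and converts them element by element to the generic precision `P`, but it also comments out the old size check. A hedged sketch of the same loop with the guard restored (identifiers taken from this diff; whether `loaderError` is the right error case here is an assumption):

let tmpCapacity = MemoryLayout<Float>.size * tensor.numel()
let tmpPointer = UnsafeMutablePointer<Float>.allocate(capacity: tensor.numel())
defer {
    tmpPointer.deinitialize(count: tensor.numel())
    tmpPointer.deallocate()
}
// Read the raw Float32 parameters and verify the byte count before converting.
let bytesRead = fread(tmpPointer, 1, tmpCapacity, file)
guard bytesRead == tmpCapacity else {
    throw PaddleMobileError.loaderError(message: "param read size error")
}
for i in 0..<tensor.numel() {
    tensor.data[i] = P.init(inFloat: tmpPointer[i])
}
nowIndex += bytesRead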
...@@ -27,19 +27,19 @@ class OpCreator<P: PrecisionType> { ...@@ -27,19 +27,19 @@ class OpCreator<P: PrecisionType> {
} }
} }
func creat(opDesc: OpDesc, scope: Scope) throws -> Runable & InferShaperable { func creat(device: MTLDevice, opDesc: OpDesc, scope: Scope) throws -> Runable & InferShaperable {
guard let opCreator = opCreators[opDesc.type] else { guard let opCreator = opCreators[opDesc.type] else {
throw PaddleMobileError.opError(message: "there is no " + opDesc.type + " yet") throw PaddleMobileError.opError(message: "there is no " + opDesc.type + " yet")
} }
do { do {
return try opCreator(opDesc, scope) return try opCreator(device, opDesc, scope)
} catch let error { } catch let error {
throw error throw error
} }
} }
let opCreators: [String : (OpDesc, Scope) throws -> Runable & InferShaperable] = let opCreators: [String : (MTLDevice, OpDesc, Scope) throws -> Runable & InferShaperable] =
[gConvType : ConvOp<P>.creat, [gConvType : ConvOp<P>.creat,
gBatchNormType : BatchNormOp<P>.creat, gBatchNormType : BatchNormOp<P>.creat,
gReluType : ReluOp<P>.creat, gReluType : ReluOp<P>.creat,
......
...@@ -12,29 +12,35 @@ ...@@ -12,29 +12,35 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
import Metal
import Foundation import Foundation
protocol Runable { protocol Runable {
func run() func run(device: MTLDevice, buffer: MTLCommandBuffer) throws
func runImpl() func runImpl(device: MTLDevice,buffer: MTLCommandBuffer) throws
} }
extension Runable where Self: OperatorProtocol{ extension Runable where Self: OperatorProtocol{
func run() { func run(device: MTLDevice, buffer: MTLCommandBuffer) throws {
runImpl() do {
try runImpl(device: device, buffer: buffer)
} catch let error {
throw error
}
print(type + ": " + para.outputDesc()) print(type + ": " + para.outputDesc())
} }
} }
protocol Creator where Self: OperatorProtocol{ protocol Creator where Self: OperatorProtocol{
associatedtype OpType: OperatorProtocol & Runable & InferShaperable associatedtype OpType: OperatorProtocol & Runable & InferShaperable
static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType static func creat(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws -> OpType
} }
extension Creator where Self: OperatorProtocol { extension Creator where Self: OperatorProtocol {
static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType { static func creat(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws -> OpType {
do { do {
return try OpType.provide(opDesc: opDesc, inScope: inScope) return try OpType.provide(device:device, opDesc: opDesc, inScope: inScope)
} catch let error { } catch let error {
throw error throw error
} }
...@@ -47,19 +53,21 @@ protocol InferShaperable { ...@@ -47,19 +53,21 @@ protocol InferShaperable {
protocol OperatorProtocol { protocol OperatorProtocol {
associatedtype ParamType: OpParam associatedtype ParamType: OpParam
associatedtype KerType: Computable
var type: String { get } var type: String { get }
var inputs: [String : [String]] { get } var inputs: [String : [String]] { get }
var paraInputs: [String : [String]] { get } var paraInputs: [String : [String]] { get }
var outpus: [String : [String]] { get } var outpus: [String : [String]] { get }
var attrs: [String : Attr] { get } var attrs: [String : Attr] { get }
var para: ParamType { get } var para: ParamType { get }
init(opDesc: OpDesc, inScope: Scope) throws var kernel: KerType { get }
init(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws
} }
extension OperatorProtocol { extension OperatorProtocol {
static func provide(opDesc: OpDesc, inScope: Scope) throws -> Self { static func provide(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws -> Self {
do { do {
return try Self.init(opDesc: opDesc, inScope: inScope) return try Self.init(device: device, opDesc: opDesc, inScope: inScope)
} catch let error { } catch let error {
throw error throw error
} }
...@@ -67,20 +75,23 @@ extension OperatorProtocol { ...@@ -67,20 +75,23 @@ extension OperatorProtocol {
} }
class Operator <ParameterType: OpParam>: OperatorProtocol{ class Operator <ParameterType: OpParam, KernelType: Computable>: OperatorProtocol{
typealias ParamType = ParameterType typealias ParamType = ParameterType
typealias KerType = KernelType
let type: String let type: String
let inputs: [String : [String]] let inputs: [String : [String]]
let paraInputs: [String : [String]] let paraInputs: [String : [String]]
let outpus: [String : [String]] let outpus: [String : [String]]
let attrs: [String : Attr] let attrs: [String : Attr]
let para: ParamType let para: ParamType
required init(opDesc: OpDesc, inScope: Scope) throws { var kernel: KerType
required init(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws {
type = opDesc.type type = opDesc.type
inputs = opDesc.inputs inputs = opDesc.inputs
outpus = opDesc.outputs outpus = opDesc.outputs
attrs = opDesc.attrs attrs = opDesc.attrs
paraInputs = opDesc.paraInputs paraInputs = opDesc.paraInputs
kernel = KerType.init(device: device)
do { do {
para = try ParamType.init(opDesc:opDesc, inScope: inScope) para = try ParamType.init(opDesc:opDesc, inScope: inScope)
} catch let error { } catch let error {
......
...@@ -31,8 +31,8 @@ struct BatchNormParam<P: PrecisionType>: OpParam { ...@@ -31,8 +31,8 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
throw error throw error
} }
} }
let input: Texture let input: Texture<P>
var output: Texture var output: Texture<P>
let inputBias: Tensor<ParamPrecisionType> let inputBias: Tensor<ParamPrecisionType>
let inputMean: Tensor<ParamPrecisionType> let inputMean: Tensor<ParamPrecisionType>
let inputScale: Tensor<ParamPrecisionType> let inputScale: Tensor<ParamPrecisionType>
...@@ -42,12 +42,12 @@ struct BatchNormParam<P: PrecisionType>: OpParam { ...@@ -42,12 +42,12 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
let is_test: Bool let is_test: Bool
} }
class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>>, Runable, Creator, InferShaperable{ class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>, BatchNormKernel<P>>, Runable, Creator, InferShaperable{
func inferShape() { func inferShape() {
para.output.dim = para.input.dim para.output.dim = para.input.dim
} }
typealias OpType = BatchNormOp<P> typealias OpType = BatchNormOp<P>
func runImpl() { func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
print("this is BatchNormOp") print("this is BatchNormOp")
} }
} }
......
...@@ -30,8 +30,8 @@ struct ConvParam<P: PrecisionType>: OpParam { ...@@ -30,8 +30,8 @@ struct ConvParam<P: PrecisionType>: OpParam {
} }
} }
let input: Texture let input: Texture<P>
var output: Texture var output: Texture<P>
let filter: Tensor<ParamPrecisionType> let filter: Tensor<ParamPrecisionType>
let stride: [Int32] let stride: [Int32]
let paddings: [Int32] let paddings: [Int32]
...@@ -39,7 +39,7 @@ struct ConvParam<P: PrecisionType>: OpParam { ...@@ -39,7 +39,7 @@ struct ConvParam<P: PrecisionType>: OpParam {
let groups: Int let groups: Int
} }
class ConvOp<P: PrecisionType>: Operator<ConvParam<P>>, Runable, Creator, InferShaperable { class ConvOp<P: PrecisionType>: Operator<ConvParam<P>, ConvKernel<P>>, Runable, Creator, InferShaperable {
func inferShape() { func inferShape() {
let inDims = para.input.dim let inDims = para.input.dim
let filterDim = para.filter.dim let filterDim = para.filter.dim
...@@ -63,7 +63,7 @@ class ConvOp<P: PrecisionType>: Operator<ConvParam<P>>, Runable, Creator, InferS ...@@ -63,7 +63,7 @@ class ConvOp<P: PrecisionType>: Operator<ConvParam<P>>, Runable, Creator, InferS
} }
typealias OpType = ConvOp<P> typealias OpType = ConvOp<P>
func runImpl() { func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
print("this is conv") print("this is conv")
} }
} }
...@@ -26,20 +26,20 @@ struct ElementwiseAddParam<P: PrecisionType>: OpParam { ...@@ -26,20 +26,20 @@ struct ElementwiseAddParam<P: PrecisionType>: OpParam {
throw error throw error
} }
} }
let input: Texture let input: Texture<P>
let inputY: Tensor<P> let inputY: Tensor<P>
var output: Texture var output: Texture<P>
let axis: Int let axis: Int
} }
class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>>, Runable, Creator, InferShaperable{ class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>, ElementwiseAddKernel<P>>, Runable, Creator, InferShaperable{
func inferShape() { func inferShape() {
para.output.dim = para.input.dim para.output.dim = para.input.dim
} }
typealias OpType = ElementwiseAddOp<P> typealias OpType = ElementwiseAddOp<P>
func runImpl() { func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
print("this is ElementwiseAddOp") print("this is ElementwiseAddOp")
} }
} }
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import Foundation import Foundation
struct FeedParam<P: PrecisionType>: OpParam{ struct FeedParam<P: PrecisionType>: OpParam{
var output: Texture var output: Texture<P>
var input: InputTexture { var input: InputTexture {
return scope.input() as! InputTexture return scope.input() as! InputTexture
} }
...@@ -33,19 +33,26 @@ struct FeedParam<P: PrecisionType>: OpParam{ ...@@ -33,19 +33,26 @@ struct FeedParam<P: PrecisionType>: OpParam{
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
} }
class FeedOp<P: PrecisionType>: Operator<FeedParam<P>>, Runable, Creator, InferShaperable { class FeedOp<P: PrecisionType>: Operator<FeedParam<P>, ResizeKernel<P>>, Runable, Creator, InferShaperable {
typealias OpType = FeedOp<P> typealias OpType = FeedOp<P>
func inferShape() { func inferShape() {
// print("feed input: \(para.input.expectDim)") // print("feed input: \(para.input.expectDim)")
print("feed output: \(para.output.dim)") print("feed output: \(para.output.dim)")
// para.output.dim =
// para.ou/tput.dim = para.input.expectDim // para.output.dim = para.input.expectDim
} }
func runImpl() { func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
print("feed op") let resizeKernel = ResizeKernel<P>.init(device: device)
// let resizeKernel = ResizeKernel.init(device: <#T##MTLDevice#>) let resizeParam = ResizeParam.init(input: para.input.mtlTexture, output: para.output.metalTexture, expectDim: para.input.expectDim)
do {
print("feed op to compute ")
try resizeKernel.compute(commandBuffer: buffer, param: resizeParam)
print("feed op end compute ")
} catch let error {
throw error
}
} }
} }
...@@ -16,7 +16,7 @@ import Foundation ...@@ -16,7 +16,7 @@ import Foundation
struct FetchParam<P: PrecisionType>: OpParam{ struct FetchParam<P: PrecisionType>: OpParam{
var output: ResultHolder<P> = ResultHolder.init(inDim: [], inResult: []) var output: ResultHolder<P> = ResultHolder.init(inDim: [], inResult: [])
let input: Texture let input: Texture<P>
let scope: Scope let scope: Scope
init(opDesc: OpDesc, inScope: Scope) throws { init(opDesc: OpDesc, inScope: Scope) throws {
scope = inScope scope = inScope
...@@ -30,14 +30,14 @@ struct FetchParam<P: PrecisionType>: OpParam{ ...@@ -30,14 +30,14 @@ struct FetchParam<P: PrecisionType>: OpParam{
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
} }
class FetchOp<P: PrecisionType>: Operator<FetchParam<P>>, Runable, Creator, InferShaperable{ class FetchOp<P: PrecisionType>: Operator<FetchParam<P>, ResizeKernel<P>>, Runable, Creator, InferShaperable{
func inferShape() { func inferShape() {
print(para.input.dim) print(para.input.dim)
} }
typealias OpType = FetchOp<P> typealias OpType = FetchOp<P>
func runImpl() { func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
print("fetch op") print("fetch op")
} }
} }
......
//
// BatchNormKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/5.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
required init(device: MTLDevice) {
super.init(device: device, inFunctionName: "batchnorm")
}
func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam<P>) throws {
}
}
//
// ConvKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/5.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
class ConvKernel<P: PrecisionType>: Kernel, Computable {
func compute(commandBuffer: MTLCommandBuffer, param: ConvParam<P>) throws {
}
required init(device: MTLDevice) {
super.init(device: device, inFunctionName: "conv")
}
}
//
// ElementwiseAddKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/5.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
class ElementwiseAddKernel<P: PrecisionType>: Kernel, Computable {
required init(device: MTLDevice) {
super.init(device: device, inFunctionName: "conv")
}
func compute(commandBuffer: MTLCommandBuffer, param: ElementwiseAddParam<P>) throws {
}
}
...@@ -18,6 +18,12 @@ import Foundation ...@@ -18,6 +18,12 @@ import Foundation
protocol Computable { protocol Computable {
associatedtype ParamType associatedtype ParamType
func compute(commandBuffer: MTLCommandBuffer, param: ParamType) throws func compute(commandBuffer: MTLCommandBuffer, param: ParamType) throws
init(device: MTLDevice)
}
protocol KernelProtocol {
var pipline: MTLComputePipelineState { get set }
var functionName: String { get set }
} }
class Kernel { class Kernel {
......
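With `Computable` now requiring `init(device:)`, adding a new GPU op only takes naming its Metal function and encoding its textures. A minimal sketch of a hypothetical kernel following the same pattern as ReluKernel further down (the `Scale` op, `ScaleParam`, and the "scale" Metal function are invented for illustration and are not part of this commit):

import Metal

class ScaleKernel<P: PrecisionType>: Kernel, Computable {
    required init(device: MTLDevice) {
        // "scale" would have to exist in Kernels.metal for the pipeline to build.
        super.init(device: device, inFunctionName: "scale")
    }

    func compute(commandBuffer: MTLCommandBuffer, param: ScaleParam<P>) throws {
        guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
            throw PaddleMobileError.predictError(message: "encoder is nil")
        }
        encoder.setTexture(param.input.metalTexture, index: 0)
        encoder.setTexture(param.output.metalTexture, index: 1)
        encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
        encoder.endEncoding()
    }
}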
// /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Kernels.metal
// paddle-mobile Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// Created by liuRuiLong on 2018/7/4. You may obtain a copy of the License at
// Copyright © 2018年 orange. All rights reserved.
// http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
...@@ -16,19 +22,70 @@ struct OutputDim { ...@@ -16,19 +22,70 @@ struct OutputDim {
ushort strideY; ushort strideY;
}; };
kernel void resize( kernel void resize(texture2d<half, access::read> inTexture [[texture(0)]],
texture2d<half, access::read> inTexture [[texture(0)]], texture2d_array<half, access::write> outTexture [[texture(1)]],
texture2d<half, access::write> outTexture [[texture(1)]],
constant OutputDim &params [[buffer(0)]], constant OutputDim &params [[buffer(0)]],
uint2 gid [[thread_position_in_grid]]) { uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() || if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height()) { gid.y >= outTexture.get_height() ||
return; gid.z >= outTexture.get_array_size()) return;
}
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint2 pos = gid.xy * uint2(params.strideX, params.strideY); const uint2 pos = gid.xy * uint2(params.strideX, params.strideY);
const half4 input = inTexture.read(pos); const half4 input = inTexture.read(pos);
outTexture.write(half4(input.x, input.y, input.z, 0.0h), gid); outTexture.write(half4(input.x, input.y, input.z, input.w), gid.xy, gid.z);
}
kernel void relu(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
const float4 relu = fmax((float4)input, 0.0);
outTexture.write(half4(relu), gid.xy, gid.z);
}
kernel void elementwise_add(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 *biasTerms [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
outTexture.write(input, gid.xy, gid.z);
} }
kernel void conv(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
const device half4 *biasTerms [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
outTexture.write(input, gid.xy, gid.z);
}
kernel void batchnorm(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
outTexture.write(input, gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class ReluKernel<P: PrecisionType>: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: ReluParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
print(" the usage of input of relu \(param.input.metalTexture.usage)")
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
required init(device: MTLDevice) {
super.init(device: device, inFunctionName: "relu")
}
}
// /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// ResizeKernel.swift
// paddle-mobile Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// Created by liuRuiLong on 2018/7/4. You may obtain a copy of the License at
// Copyright © 2018年 orange. All rights reserved.
// http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation import Foundation
...@@ -22,7 +28,7 @@ struct OutputDim { ...@@ -22,7 +28,7 @@ struct OutputDim {
let strideY: UInt16 let strideY: UInt16
} }
class ResizeKernel: Kernel, Computable{ class ResizeKernel<P: PrecisionType>: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: ResizeParam) throws { func compute(commandBuffer: MTLCommandBuffer, param: ResizeParam) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else { guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil") throw PaddleMobileError.predictError(message: " encode is nil")
...@@ -30,7 +36,6 @@ class ResizeKernel: Kernel, Computable{ ...@@ -30,7 +36,6 @@ class ResizeKernel: Kernel, Computable{
encoder.setTexture(param.input, index: 0) encoder.setTexture(param.input, index: 0)
encoder.setTexture(param.output, index: 1) encoder.setTexture(param.output, index: 1)
let strideX = param.input.width/param.expectDim[2] let strideX = param.input.width/param.expectDim[2]
let strideY = param.input.height/param.expectDim[1] let strideY = param.input.height/param.expectDim[1]
var outputDim = OutputDim.init(width: UInt16(param.expectDim[1]), height: UInt16(param.expectDim[2]), strideX: UInt16(strideX), strideY: UInt16(strideY)) var outputDim = OutputDim.init(width: UInt16(param.expectDim[1]), height: UInt16(param.expectDim[2]), strideX: UInt16(strideX), strideY: UInt16(strideY))
...@@ -39,7 +44,7 @@ class ResizeKernel: Kernel, Computable{ ...@@ -39,7 +44,7 @@ class ResizeKernel: Kernel, Computable{
encoder.endEncoding() encoder.endEncoding()
} }
init(device: MTLDevice) { required init(device: MTLDevice) {
super.init(device: device, inFunctionName: "resize") super.init(device: device, inFunctionName: "resize")
} }
} }
......
...@@ -24,19 +24,23 @@ struct ReluParam<P: PrecisionType>: OpParam { ...@@ -24,19 +24,23 @@ struct ReluParam<P: PrecisionType>: OpParam {
throw error throw error
} }
} }
let input: Texture let input: Texture<P>
var output: Texture var output: Texture<P>
} }
class ReluOp<P: PrecisionType>: Operator<ReluParam<P>>, Runable, Creator, InferShaperable{ class ReluOp<P: PrecisionType>: Operator<ReluParam<P>, ReluKernel<P>>, Runable, Creator, InferShaperable{
func inferShape() { func inferShape() {
para.output.dim = para.input.dim para.output.dim = para.input.dim
} }
typealias OpType = ReluOp<P> typealias OpType = ReluOp<P>
func runImpl() { func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
print("this is ReluOp") do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
} }
} }
......
...@@ -38,7 +38,7 @@ extension InputTexture { ...@@ -38,7 +38,7 @@ extension InputTexture {
} }
} }
public class Texture: Tensorial { public class Texture<P: PrecisionType>: Tensorial {
var dim: Dim var dim: Dim
let textureDesc: MTLTextureDescriptor let textureDesc: MTLTextureDescriptor
var metalTexture: MTLTexture var metalTexture: MTLTexture
...@@ -61,7 +61,15 @@ public class Texture: Tensorial { ...@@ -61,7 +61,15 @@ public class Texture: Tensorial {
} else { } else {
fatalError(" didn't support yet") fatalError(" didn't support yet")
} }
if MemoryLayout<P>.size == 1 {
tmpTextureDes.pixelFormat = .r8Sint
} else if MemoryLayout<P>.size == 2 {
tmpTextureDes.pixelFormat = .r16Float
} else if MemoryLayout<P>.size == 4 {
tmpTextureDes.pixelFormat = .r32Float tmpTextureDes.pixelFormat = .r32Float
}
tmpTextureDes.usage = .unknown
tmpTextureDes.storageMode = .shared tmpTextureDes.storageMode = .shared
textureDesc = tmpTextureDes textureDesc = tmpTextureDes
metalTexture = device.makeTexture(descriptor: tmpTextureDes) ?! " texture nil " metalTexture = device.makeTexture(descriptor: tmpTextureDes) ?! " texture nil "
......