diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist b/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist index 125fd5ec745f635929df2d4e29f8147ab9a6b83a..e2c6b20680bd2f1209dfea198f9f62f2171716cf 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist +++ b/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist @@ -7,7 +7,7 @@ paddle-mobile-demo.xcscheme orderHint - 4 + 3 diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift b/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift index 5edacc34578637b76e3a09f9cbce7213cbac3e54..1396e95f4e61b85627e1d9349eb24e91be7e6d6a 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift +++ b/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift @@ -36,13 +36,13 @@ class ViewController: UIViewController { fatalError(" texture is nil !") } - let loader = Loader.init() + let loader = Loader.init() do { let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null" let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! 
"para null" let program = try loader.load(device: device, modelPath: modelPath, paraPath: paraPath) - let executor = try Executor.init(inProgram: program) - let output = try executor.predict(input: inTexture, expect: [1, 224, 224, 3]) + let executor = try Executor.init(inDevice: device, inQueue: queue!, inProgram: program) + let output = try executor.predict(input: inTexture, expect: [1, 227, 227, 3]) print(output) } catch let error { print(error) diff --git a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj index c20bfa4f839b35f75f44d9393da9f80e1a8627b5..5234f24ddb107c5d809e873eafa8ad297e4cdb95 100644 --- a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj @@ -30,6 +30,10 @@ FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB520E11CC20081E9F8 /* OpDesc.swift */; }; FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB620E11CC20081E9F8 /* Attribute.swift */; }; FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB720E11CC20081E9F8 /* BlockDesc.swift */; }; + FC0E2DBA20EE3B8D009C1FAC /* ReluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DB920EE3B8D009C1FAC /* ReluKernel.swift */; }; + FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBB20EE45FE009C1FAC /* ConvKernel.swift */; }; + FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBD20EE460D009C1FAC /* BatchNormKernel.swift */; }; + FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */; }; FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC1B16B220EC9A4F00678B91 /* Kernels.metal 
*/; }; FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC1B186520ECF1C600678B91 /* ResizeKernel.swift */; }; FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC60DB8820E9AAA500FF203F /* MetalExtension.swift */; }; @@ -69,6 +73,10 @@ FC039BB520E11CC20081E9F8 /* OpDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OpDesc.swift; sourceTree = ""; }; FC039BB620E11CC20081E9F8 /* Attribute.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Attribute.swift; sourceTree = ""; }; FC039BB720E11CC20081E9F8 /* BlockDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BlockDesc.swift; sourceTree = ""; }; + FC0E2DB920EE3B8D009C1FAC /* ReluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ReluKernel.swift; sourceTree = ""; }; + FC0E2DBB20EE45FE009C1FAC /* ConvKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvKernel.swift; sourceTree = ""; }; + FC0E2DBD20EE460D009C1FAC /* BatchNormKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BatchNormKernel.swift; sourceTree = ""; }; + FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ElementwiseAddKernel.swift; sourceTree = ""; }; FC1B16B220EC9A4F00678B91 /* Kernels.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = Kernels.metal; sourceTree = ""; }; FC1B186520ECF1C600678B91 /* ResizeKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ResizeKernel.swift; sourceTree = ""; }; FC60DB8820E9AAA500FF203F /* MetalExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalExtension.swift; sourceTree = ""; }; @@ 
-197,9 +205,13 @@ FC086BA520E67E8500D85EF7 /* Kernels */ = { isa = PBXGroup; children = ( + FC0E2DBB20EE45FE009C1FAC /* ConvKernel.swift */, FCF2D73720E64E70007AC5F5 /* Kernel.swift */, FC1B16B220EC9A4F00678B91 /* Kernels.metal */, FC1B186520ECF1C600678B91 /* ResizeKernel.swift */, + FC0E2DB920EE3B8D009C1FAC /* ReluKernel.swift */, + FC0E2DBD20EE460D009C1FAC /* BatchNormKernel.swift */, + FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */, ); path = Kernels; sourceTree = ""; @@ -316,12 +328,14 @@ files = ( FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */, FC039B9F20E11CB20081E9F8 /* Tensor.swift in Sources */, + FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */, FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */, FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */, FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */, FC9D037920E229E4000F735A /* OpParam.swift in Sources */, FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */, FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */, + FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */, FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */, FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */, FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */, @@ -335,7 +349,9 @@ FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */, FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */, FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */, + FC0E2DBA20EE3B8D009C1FAC /* ReluKernel.swift in Sources */, FC82735920E3C04200BE430A /* OpCreator.swift in Sources */, + FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */, FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */, FC9D038220E2312E000F735A /* FetchOp.swift in Sources */, FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */, diff --git a/metal/paddle-mobile/paddle-mobile.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist 
b/metal/paddle-mobile/paddle-mobile.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist index 52725213607bd68fccbd4aa63faa9294cb622962..50f16e4d7cfc755905c68ced31769d3ac1dd1049 100644 --- a/metal/paddle-mobile/paddle-mobile.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist +++ b/metal/paddle-mobile/paddle-mobile.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist @@ -7,7 +7,7 @@ paddle-mobile.xcscheme orderHint - 3 + 4 diff --git a/metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift b/metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift index 66166b8388ebab9d9196b5e78d0765bdc3a91877..62dfdfa9f96e631f53cc093611217e7e9b5b92df 100644 --- a/metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift +++ b/metal/paddle-mobile/paddle-mobile/Common/MetalExtension.swift @@ -29,11 +29,11 @@ extension MTLDevice { fatalError("Counld't find paddle mobile library") } do { + print(path) paddleMobileMetalLibrary = try makeLibrary(filepath: path) } catch _ { fatalError("Counld't load paddle mobile library") } - paddleMobileMetalLibrary = makeDefaultLibrary() } if let inPaddleMobileLib = paddleMobileMetalLibrary { @@ -67,11 +67,17 @@ extension MTLComputeCommandEncoder { let height = computePipline.maxTotalThreadsPerThreadgroup/width let threadsPerGroup = MTLSize.init(width: width, height: height, depth: 1) + print(" threads per group: \(threadsPerGroup) ") + + print(" out texture width: \(outTexture.width) , out texture height: \(outTexture.height)") + let groupWidth = (outTexture.width + width - 1)/width let groupHeight = (outTexture.height + height - 1)/height let groupDepth = slices let groups = MTLSize.init(width: groupWidth, height: groupHeight, depth: groupDepth) + print("groups: \(groups) ") + setComputePipelineState(computePipline) dispatchThreadgroups(groups, threadsPerThreadgroup: threadsPerGroup) } diff --git a/metal/paddle-mobile/paddle-mobile/Common/Types.swift 
b/metal/paddle-mobile/paddle-mobile/Common/Types.swift index f0d7c194d696942e03a3a991fe002a08899ee62f..f910d336281e41a8af2d778e9e9f44dcc5fe2e33 100644 --- a/metal/paddle-mobile/paddle-mobile/Common/Types.swift +++ b/metal/paddle-mobile/paddle-mobile/Common/Types.swift @@ -14,14 +14,22 @@ import Foundation -//typealias Float16 = Int16 -//extension Float16: PrecisionType { -//} +public typealias Float16 = Int16 +extension Float16: PrecisionType { + public init(inFloat: Float32) { + self = Int16(inFloat) + } +} public protocol PrecisionType { + init(inFloat: Float32) } + extension Float32: PrecisionType { + public init(inFloat: Float32) { + self = inFloat + } } public enum DataLayout { diff --git a/metal/paddle-mobile/paddle-mobile/Executor.swift b/metal/paddle-mobile/paddle-mobile/Executor.swift index 4baca148b118069e3a0f8c1479ca1507899d385e..6f190242f231efb22f90ae26df533529824b2c04 100644 --- a/metal/paddle-mobile/paddle-mobile/Executor.swift +++ b/metal/paddle-mobile/paddle-mobile/Executor.swift @@ -48,13 +48,16 @@ extension ResultHolder: CustomDebugStringConvertible, CustomStringConvertible { public class Executor { var ops: [Runable & InferShaperable] = [] let program: Program - - public init(inProgram: Program) throws { + let device: MTLDevice + let queue: MTLCommandQueue + public init(inDevice:MTLDevice, inQueue: MTLCommandQueue, inProgram: Program) throws { program = inProgram + device = inDevice + queue = inQueue for block in inProgram.programDesc.blocks { for op in block.ops { do { - let op = try OpCreator

.shared.creat(opDesc: op, scope: inProgram.scope) + let op = try OpCreator

.shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope) op.inferShape() ops.append(op) } catch let error { @@ -65,12 +68,29 @@ public class Executor { } public func predict(input: MTLTexture, expect: [Int]) throws -> ResultHolder

{ + let beforeDate = Date.init() let inputTexture = InputTexture.init(inMTLTexture: input, inExpectDim: Dim.init(inDim: expect)) program.scope.setInput(input: inputTexture) + guard let buffer = queue.makeCommandBuffer() else { + throw PaddleMobileError.predictError(message: "CommandBuffer is nil") + } + for op in ops { - op.run() + do { + try op.run(device: device, buffer: buffer) + } catch let error { + throw error + } } + buffer.addCompletedHandler { (commandbuffer) in + let afterDate = Date.init() + print(afterDate.timeIntervalSince(beforeDate)) + print(" encoder end ! ") + } + + buffer.commit() + guard let outputVar = program.scope.output() else { throw PaddleMobileError.netError(message: "output nil") } @@ -78,6 +98,8 @@ public class Executor { guard let output = outputVar as? ResultHolder

else { throw PaddleMobileError.netError(message: "output var type error") } + + return output } diff --git a/metal/paddle-mobile/paddle-mobile/Loader.swift b/metal/paddle-mobile/paddle-mobile/Loader.swift index 8bec6f6ebc76587ff047768d782647d0f9ac51e1..4ef1e1b0720561d45d8bceaa93d4cbe6b09074d0 100644 --- a/metal/paddle-mobile/paddle-mobile/Loader.swift +++ b/metal/paddle-mobile/paddle-mobile/Loader.swift @@ -68,11 +68,24 @@ public class Loader { /* 这里没有根据 Data Type 去判断, 而是从外部泛型直接指定了精度 */ - - let bytesRead = fread(tensor.data.pointer, 1, tensor.data.size, file) - guard bytesRead == tensor.data.size else { - throw PaddleMobileError.loaderError(message: "param read size error") + + //现在模型传入模型为 Float 类型, 这块应该根据模型来 + let tmpCapacity = MemoryLayout.size * tensor.numel() + let tmpPointer = UnsafeMutablePointer.allocate(capacity: tmpCapacity); + +// let bytesRead = fread(tensor.data.pointer, 1, tensor.data.size, file) +// guard bytesRead == tensor.data.size else { +// throw PaddleMobileError.loaderError(message: "param read size error") +// } + + // TODO: use script to convert + let bytesRead = fread(tmpPointer, 1, tmpCapacity, file) + for i in 0.. { throw PaddleMobileError.loaderError(message: "get tensor desc failed") } - guard (try? tensorDesc.dataType.dataTypeSize()) == MemoryLayout

.size else { - throw PaddleMobileError.memoryError(message: "PrecisionType not support") - } +// guard (try? tensorDesc.dataType.dataTypeSize()) == MemoryLayout

.size else { +// throw PaddleMobileError.memoryError(message: "PrecisionType not support") +// } if (varDesc.persistable && varDesc.type != .FeedMiniBatch @@ -149,7 +162,7 @@ public class Loader { scope[varDesc.name] = tensor } else { let dim = Dim.init(inDim: tensorDesc.NHWCDim) - scope[varDesc.name] = Texture.init(device: device, inDim: dim) + scope[varDesc.name] = Texture

.init(device: device, inDim: dim) } } else { if varDesc.name == fetchKey { diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift b/metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift index dfe5be4240ba48dadbb144dc95015c7ab318051f..579526d7e7e4a2f6570b97553eb0c3d3c6d7530c 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Base/OpCreator.swift @@ -27,19 +27,19 @@ class OpCreator { } } - func creat(opDesc: OpDesc, scope: Scope) throws -> Runable & InferShaperable { + func creat(device: MTLDevice, opDesc: OpDesc, scope: Scope) throws -> Runable & InferShaperable { guard let opCreator = opCreators[opDesc.type] else { throw PaddleMobileError.opError(message: "there is no " + opDesc.type + " yet") } do { - return try opCreator(opDesc, scope) + return try opCreator(device, opDesc, scope) } catch let error { throw error } } - let opCreators: [String : (OpDesc, Scope) throws -> Runable & InferShaperable] = + let opCreators: [String : (MTLDevice, OpDesc, Scope) throws -> Runable & InferShaperable] = [gConvType : ConvOp

.creat, gBatchNormType : BatchNormOp

.creat, gReluType : ReluOp

.creat, diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift b/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift index d74d5ff93ce165313394b452679f1fb72593f898..222c5319c6d7dfd88d8a83140bef98e67dcb2fd5 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift @@ -12,29 +12,35 @@ See the License for the specific language governing permissions and limitations under the License. */ +import Metal import Foundation protocol Runable { - func run() - func runImpl() + func run(device: MTLDevice, buffer: MTLCommandBuffer) throws + func runImpl(device: MTLDevice,buffer: MTLCommandBuffer) throws } extension Runable where Self: OperatorProtocol{ - func run() { - runImpl() + func run(device: MTLDevice, buffer: MTLCommandBuffer) throws { + do { + try runImpl(device: device, buffer: buffer) + } catch let error { + throw error + } + print(type + ": " + para.outputDesc()) } } protocol Creator where Self: OperatorProtocol{ associatedtype OpType: OperatorProtocol & Runable & InferShaperable - static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType + static func creat(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws -> OpType } extension Creator where Self: OperatorProtocol { - static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType { + static func creat(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws -> OpType { do { - return try OpType.provide(opDesc: opDesc, inScope: inScope) + return try OpType.provide(device:device, opDesc: opDesc, inScope: inScope) } catch let error { throw error } @@ -47,19 +53,21 @@ protocol InferShaperable { protocol OperatorProtocol { associatedtype ParamType: OpParam + associatedtype KerType: Computable var type: String { get } var inputs: [String : [String]] { get } var paraInputs: [String : [String]] { get } var outpus: [String : [String]] { get } var attrs: [String : Attr] { get } var 
para: ParamType { get } - init(opDesc: OpDesc, inScope: Scope) throws + var kernel: KerType { get } + init(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws } extension OperatorProtocol { - static func provide(opDesc: OpDesc, inScope: Scope) throws -> Self { + static func provide(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws -> Self { do { - return try Self.init(opDesc: opDesc, inScope: inScope) + return try Self.init(device: device, opDesc: opDesc, inScope: inScope) } catch let error { throw error } @@ -67,20 +75,23 @@ extension OperatorProtocol { } -class Operator : OperatorProtocol{ - typealias ParamType = ParameterType +class Operator : OperatorProtocol{ + typealias ParamType = ParameterType + typealias KerType = KernelType let type: String let inputs: [String : [String]] let paraInputs: [String : [String]] let outpus: [String : [String]] let attrs: [String : Attr] let para: ParamType - required init(opDesc: OpDesc, inScope: Scope) throws { + var kernel: KerType + required init(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws { type = opDesc.type inputs = opDesc.inputs outpus = opDesc.outputs attrs = opDesc.attrs paraInputs = opDesc.paraInputs + kernel = KerType.init(device: device) do { para = try ParamType.init(opDesc:opDesc, inScope: inScope) } catch let error { diff --git a/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift index 9c6fb71fa17d799c11df6abbe9a9fee894e8865e..150004f2ce8c718092d255226b8e1c497fc4bba0 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift @@ -31,8 +31,8 @@ struct BatchNormParam: OpParam { throw error } } - let input: Texture - var output: Texture + let input: Texture

+ var output: Texture

let inputBias: Tensor let inputMean: Tensor let inputScale: Tensor @@ -42,12 +42,12 @@ struct BatchNormParam: OpParam { let is_test: Bool } -class BatchNormOp: Operator>, Runable, Creator, InferShaperable{ +class BatchNormOp: Operator, BatchNormKernel

>, Runable, Creator, InferShaperable{ func inferShape() { para.output.dim = para.input.dim } typealias OpType = BatchNormOp

- func runImpl() { + func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { print("this is BatchNormOp") } } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift index 3b1dcfc47050c79d6053bc0463272c0a88665c45..b4623bf567cbcb17c99b28e246d3bf0d242bc23b 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift @@ -30,8 +30,8 @@ struct ConvParam: OpParam { } } - let input: Texture - var output: Texture + let input: Texture

+ var output: Texture

let filter: Tensor let stride: [Int32] let paddings: [Int32] @@ -39,7 +39,7 @@ struct ConvParam: OpParam { let groups: Int } -class ConvOp: Operator>, Runable, Creator, InferShaperable { +class ConvOp: Operator, ConvKernel

>, Runable, Creator, InferShaperable { func inferShape() { let inDims = para.input.dim let filterDim = para.filter.dim @@ -63,7 +63,7 @@ class ConvOp: Operator>, Runable, Creator, InferS } typealias OpType = ConvOp

- func runImpl() { + func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { print("this is conv") } } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift index f0966c5b4d8ee358443fdbdf32d99d3e741aba1f..12c0b832b76b0975d4b5a138c283866166e06130 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift @@ -26,20 +26,20 @@ struct ElementwiseAddParam: OpParam { throw error } } - let input: Texture + let input: Texture

let inputY: Tensor

- var output: Texture + var output: Texture

let axis: Int } -class ElementwiseAddOp: Operator>, Runable, Creator, InferShaperable{ +class ElementwiseAddOp: Operator, ElementwiseAddKernel

>, Runable, Creator, InferShaperable{ func inferShape() { para.output.dim = para.input.dim } typealias OpType = ElementwiseAddOp

- func runImpl() { + func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { print("this is ElementwiseAddOp") } } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift index 9055fcd7e7b6a97d6e291d5086e5fd10855e2d82..87b71a7a95d316cd226e4e648506fbe011d5e4c8 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift @@ -15,7 +15,7 @@ import Foundation struct FeedParam: OpParam{ - var output: Texture + var output: Texture

var input: InputTexture { return scope.input() as! InputTexture } @@ -33,19 +33,26 @@ struct FeedParam: OpParam{ typealias ParamPrecisionType = P } -class FeedOp: Operator>, Runable, Creator, InferShaperable { +class FeedOp: Operator, ResizeKernel

>, Runable, Creator, InferShaperable { typealias OpType = FeedOp

func inferShape() { // print("feed input: \(para.input.expectDim)") print("feed output: \(para.output.dim)") - -// para.ou/tput.dim = para.input.expectDim + // para.output.dim = +// para.output.dim = para.input.expectDim } - func runImpl() { - print("feed op") -// let resizeKernel = ResizeKernel.init(device: <#T##MTLDevice#>) + func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { + let resizeKernel = ResizeKernel

.init(device: device) + let resizeParam = ResizeParam.init(input: para.input.mtlTexture, output: para.output.metalTexture, expectDim: para.input.expectDim) + do { + print("feed op to compute ") + try resizeKernel.compute(commandBuffer: buffer, param: resizeParam) + print("feed op end compute ") + } catch let error { + throw error + } } } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift index 6e7024909a65475fd1870d7bae01894dbe7f0395..b79538fe7770dd818eb7399547d2621249938c89 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift @@ -16,7 +16,7 @@ import Foundation struct FetchParam: OpParam{ var output: ResultHolder

= ResultHolder.init(inDim: [], inResult: []) - let input: Texture + let input: Texture

let scope: Scope init(opDesc: OpDesc, inScope: Scope) throws { scope = inScope @@ -30,14 +30,14 @@ struct FetchParam: OpParam{ typealias ParamPrecisionType = P } -class FetchOp: Operator>, Runable, Creator, InferShaperable{ +class FetchOp: Operator, ResizeKernel

>, Runable, Creator, InferShaperable{ func inferShape() { print(para.input.dim) } typealias OpType = FetchOp

- func runImpl() { + func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { print("fetch op") } } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift new file mode 100644 index 0000000000000000000000000000000000000000..b525b9bf61a2e66f6e417076eea813d790fe909c --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/BatchNormKernel.swift @@ -0,0 +1,19 @@ +// +// BatchNormKernel.swift +// paddle-mobile +// +// Created by liuRuiLong on 2018/7/5. +// Copyright © 2018年 orange. All rights reserved. +// + +import Foundation + +class BatchNormKernel: Kernel, Computable { + required init(device: MTLDevice) { + super.init(device: device, inFunctionName: "batchnorm") + } + + func compute(commandBuffer: MTLCommandBuffer, param: BatchNormParam

) throws { + + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.swift new file mode 100644 index 0000000000000000000000000000000000000000..0cdab6f6b967214a0d9e8a11ae8768b252ead38e --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ConvKernel.swift @@ -0,0 +1,20 @@ +// +// ConvKernel.swift +// paddle-mobile +// +// Created by liuRuiLong on 2018/7/5. +// Copyright © 2018年 orange. All rights reserved. +// + +import Foundation + + +class ConvKernel: Kernel, Computable { + func compute(commandBuffer: MTLCommandBuffer, param: ConvParam

) throws { + + } + required init(device: MTLDevice) { + super.init(device: device, inFunctionName: "conv") + } + +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ElementwiseAddKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ElementwiseAddKernel.swift new file mode 100644 index 0000000000000000000000000000000000000000..a93f04842b1b3ce619e2a679c9283853743f1358 --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ElementwiseAddKernel.swift @@ -0,0 +1,20 @@ +// +// ElementwiseAddKernel.swift +// paddle-mobile +// +// Created by liuRuiLong on 2018/7/5. +// Copyright © 2018年 orange. All rights reserved. +// + +import Foundation + + +class ElementwiseAddKernel: Kernel, Computable { + required init(device: MTLDevice) { + super.init(device: device, inFunctionName: "elementwise_add") // bind the elementwise_add function declared in Kernels.metal (was "conv", a copy/paste slip from ConvKernel) + } + + func compute(commandBuffer: MTLCommandBuffer, param: ElementwiseAddParam

) throws { + + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernel.swift index 1fa44c4b975ce543b2a5feafb4a5e19df419a043..b22da765d961837b29b6052413ed4d019f8f1a63 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernel.swift @@ -18,6 +18,12 @@ import Foundation protocol Computable { associatedtype ParamType func compute(commandBuffer: MTLCommandBuffer, param: ParamType) throws + init(device: MTLDevice) +} + +protocol KernelProtocol { + var pipline: MTLComputePipelineState { get set } + var functionName: String { get set } } class Kernel { diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal index 0b23f5eb4e398ff787d45829db89a4cc3b3b2032..0d872196c9d19497fa37a35dc9b294843624ca56 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal @@ -1,10 +1,16 @@ -// -// Kernels.metal -// paddle-mobile -// -// Created by liuRuiLong on 2018/7/4. -// Copyright © 2018年 orange. All rights reserved. -// +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
 */ #include <metal_stdlib> using namespace metal; @@ -16,19 +22,70 @@ struct OutputDim { ushort strideY; }; -kernel void resize( - texture2d inTexture [[texture(0)]], - texture2d outTexture [[texture(1)]], - constant OutputDim &params [[buffer(0)]], - uint2 gid [[thread_position_in_grid]]) { +kernel void resize(texture2d inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + constant OutputDim &params [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { if (gid.x >= outTexture.get_width() || - gid.y >= outTexture.get_height()) { - return; - } + gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) return; constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); const uint2 pos = gid.xy * uint2(params.strideX, params.strideY); const half4 input = inTexture.read(pos); - outTexture.write(half4(input.x, input.y, input.z, 0.0h), gid); + outTexture.write(half4(input.x, input.y, input.z, input.w), gid.xy, gid.z); } + +kernel void relu(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + uint3 gid [[thread_position_in_grid]]) { + if (gid.x >= outTexture.get_width() || + gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) return; + constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); + const half4 input = inTexture.read(gid.xy, gid.z); + const float4 relu = fmax((float4)input, 0.0); + outTexture.write(half4(relu), gid.xy, gid.z); +} + + +kernel void elementwise_add(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + const device half4 *biasTerms [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + if (gid.x >= outTexture.get_width() || + gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) return; + constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); + const half4 input = inTexture.read(gid.xy, gid.z); + outTexture.write(input, gid.xy, gid.z); +} + + +kernel void 
conv(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + const device half4 *biasTerms [[buffer(0)]], + uint3 gid [[thread_position_in_grid]]) { + if (gid.x >= outTexture.get_width() || + gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) return; + constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); + const half4 input = inTexture.read(gid.xy, gid.z); + outTexture.write(input, gid.xy, gid.z); +} + +kernel void batchnorm(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], + uint3 gid [[thread_position_in_grid]]) { + if (gid.x >= outTexture.get_width() || + gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) return; + constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); + const half4 input = inTexture.read(gid.xy, gid.z); + outTexture.write(input, gid.xy, gid.z); +} + + + + diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReluKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReluKernel.swift new file mode 100644 index 0000000000000000000000000000000000000000..e14cf9f942d01a720b44e4736d176768944947cd --- /dev/null +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ReluKernel.swift @@ -0,0 +1,32 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +import Foundation + +class ReluKernel: Kernel, Computable{ + func compute(commandBuffer: MTLCommandBuffer, param: ReluParam

) throws { + guard let encoder = commandBuffer.makeComputeCommandEncoder() else { + throw PaddleMobileError.predictError(message: " encode is nil") + } + print(" the usage of input of relu \(param.input.metalTexture.usage)") + encoder.setTexture(param.input.metalTexture, index: 0) + encoder.setTexture(param.output.metalTexture, index: 1) + encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture) + encoder.endEncoding() + } + + required init(device: MTLDevice) { + super.init(device: device, inFunctionName: "relu") + } +} diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ResizeKernel.swift b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ResizeKernel.swift index edd79fb0c005c172a6d687f893be28b1319342ae..9de09ccb98d71eb41ae576ad19677b7f46eea70b 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ResizeKernel.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/ResizeKernel.swift @@ -1,10 +1,16 @@ -// -// ResizeKernel.swift -// paddle-mobile -// -// Created by liuRuiLong on 2018/7/4. -// Copyright © 2018年 orange. All rights reserved. -// +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ import Foundation @@ -22,15 +28,14 @@ struct OutputDim { let strideY: UInt16 } -class ResizeKernel: Kernel, Computable{ +class ResizeKernel: Kernel, Computable{ func compute(commandBuffer: MTLCommandBuffer, param: ResizeParam) throws { guard let encoder = commandBuffer.makeComputeCommandEncoder() else { throw PaddleMobileError.predictError(message: " encode is nil") } encoder.setTexture(param.input, index: 0) - encoder.setTexture(param.output, index: 1) - + encoder.setTexture(param.output, index: 1) let strideX = param.input.width/param.expectDim[2] let strideY = param.input.height/param.expectDim[1] var outputDim = OutputDim.init(width: UInt16(param.expectDim[1]), height: UInt16(param.expectDim[2]), strideX: UInt16(strideX), strideY: UInt16(strideY)) @@ -39,7 +44,7 @@ class ResizeKernel: Kernel, Computable{ encoder.endEncoding() } - init(device: MTLDevice) { + required init(device: MTLDevice) { super.init(device: device, inFunctionName: "resize") } } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift index e9a8d9557ef7d75a58d1a8f7f7e4208efc24a921..d294f9b119c88dc75b78462a60ec0cfc7f2e9bc3 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift @@ -24,19 +24,23 @@ struct ReluParam: OpParam { throw error } } - let input: Texture - var output: Texture + let input: Texture

+ var output: Texture

} -class ReluOp: Operator>, Runable, Creator, InferShaperable{ +class ReluOp: Operator, ReluKernel

>, Runable, Creator, InferShaperable{ func inferShape() { para.output.dim = para.input.dim } typealias OpType = ReluOp

- func runImpl() { - print("this is ReluOp") + func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws { + do { + try kernel.compute(commandBuffer: buffer, param: para) + } catch let error { + throw error + } } } diff --git a/metal/paddle-mobile/paddle-mobile/framework/Texture.swift b/metal/paddle-mobile/paddle-mobile/framework/Texture.swift index 1e650c99e65c1333617f0213a9425b2ffe0eb817..141dc70df2e966516eaa340cbecbaa6ccb722697 100644 --- a/metal/paddle-mobile/paddle-mobile/framework/Texture.swift +++ b/metal/paddle-mobile/paddle-mobile/framework/Texture.swift @@ -38,7 +38,7 @@ extension InputTexture { } } -public class Texture: Tensorial { +public class Texture: Tensorial { var dim: Dim let textureDesc: MTLTextureDescriptor var metalTexture: MTLTexture @@ -61,7 +61,15 @@ public class Texture: Tensorial { } else { fatalError(" didn't support yet") } - tmpTextureDes.pixelFormat = .r32Float + if MemoryLayout

.size == 1 { + tmpTextureDes.pixelFormat = .r8Sint + } else if MemoryLayout

.size == 2 { + tmpTextureDes.pixelFormat = .r16Float + } else if MemoryLayout

.size == 4 { + tmpTextureDes.pixelFormat = .r32Float + } + + tmpTextureDes.usage = .unknown tmpTextureDes.storageMode = .shared textureDesc = tmpTextureDes metalTexture = device.makeTexture(descriptor: tmpTextureDes) ?! " texture nil "