diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist b/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist index 125fd5ec745f635929df2d4e29f8147ab9a6b83a..e2c6b20680bd2f1209dfea198f9f62f2171716cf 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist +++ b/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist @@ -7,7 +7,7 @@ paddle-mobile-demo.xcscheme orderHint - 4 + 3 diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift b/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift index 875f61726c404f941937aacd8ca077cf32ae765e..2a8ee13525d4a511aef014638d142939bc0ed492 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift +++ b/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift @@ -24,8 +24,9 @@ class ViewController: UIViewController { let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null" let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null" let program = try loader.load(modelPath: modelPath, paraPath: paraPath) - let executor = try Executor.init(program: program) - executor.predict() + let executor = try Executor.init(inProgram: program) + let output = try executor.predict(input: Texture.init()) + print(output) } catch let error { print(error) } diff --git a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj index 4c3e982aa8e35e747e582d988b57538dcf946829..27017d3f89b79d090f8f8089217ad7634829e6ba 100644 --- a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj @@ -155,12 +155,12 @@ FC039BA320E11CBC0081E9F8 /* Operators */ = { isa = PBXGroup; children = ( + FCD592FA20E248EC00252966 /* Base */, + FCD592F920E248EC00252966 /* Kernels */, FC039BA420E11CBC0081E9F8 /* ConvOp.swift */, FC039BA520E11CBC0081E9F8 /* ElementwiseAddOp.swift */, - FC039BA620E11CBC0081E9F8 /* Operator.swift */, FC039BA720E11CBC0081E9F8 /* BatchNormOp.swift */, FC039BA820E11CBC0081E9F8 /* ReluOp.swift */, - FC9D037820E229E4000F735A /* OpParam.swift */, FC9D037F20E22FBB000F735A /* FeedOp.swift */, FC9D038120E2312E000F735A /* FetchOp.swift */, ); @@ -183,6 +183,22 @@ path = Program; sourceTree = ""; }; + FCD592F920E248EC00252966 /* Kernels */ = { + isa = PBXGroup; + children = ( + ); + path = Kernels; + sourceTree = ""; + }; + FCD592FA20E248EC00252966 /* Base */ = { + isa = PBXGroup; + children = ( + FC9D037820E229E4000F735A /* OpParam.swift */, + FC039BA620E11CBC0081E9F8 /* Operator.swift */, + ); + path = Base; + sourceTree = ""; + }; /* End PBXGroup section */ /* Begin PBXHeadersBuildPhase section */ diff --git a/metal/paddle-mobile/paddle-mobile.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist b/metal/paddle-mobile/paddle-mobile.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist index 52725213607bd68fccbd4aa63faa9294cb622962..50f16e4d7cfc755905c68ced31769d3ac1dd1049 100644 --- a/metal/paddle-mobile/paddle-mobile.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist +++ b/metal/paddle-mobile/paddle-mobile.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist @@ -7,7 +7,7 @@ paddle-mobile.xcscheme orderHint - 3 + 4 diff --git a/metal/paddle-mobile/paddle-mobile/Common/Types.swift b/metal/paddle-mobile/paddle-mobile/Common/Types.swift index 38a7a9d70ea87a46a1bc58e5ddb539fbf14041bb..d2dd95b895796aba6091926bf745ad2af1d57740 100644 --- a/metal/paddle-mobile/paddle-mobile/Common/Types.swift +++ b/metal/paddle-mobile/paddle-mobile/Common/Types.swift @@ -60,7 +60,7 @@ class OpCreator { } } - func creat(opDesc: OpDesc, scope: Scope) throws -> Runable { + func creat(opDesc: OpDesc, scope: Scope) throws -> Runable & InferShaperable { guard let opCreator = opCreators[opDesc.type] else { throw PaddleMobileError.opError(message: "there is no " + opDesc.type + " yet") } @@ -72,7 +72,7 @@ class OpCreator { } } - let opCreators: [String : (OpDesc, Scope) throws -> Runable] = + let opCreators: [String : (OpDesc, Scope) throws -> Runable & InferShaperable] = [gConvType : ConvOp

.creat, gBatchNormType : BatchNormOp

.creat, gReluType : ReluOp

.creat, diff --git a/metal/paddle-mobile/paddle-mobile/Executor.swift b/metal/paddle-mobile/paddle-mobile/Executor.swift index 7274e7c0d86fd8d44bd83f515f68ca3c4932742b..92370bd724378e412dd34344c8769f21b7fc5d3e 100644 --- a/metal/paddle-mobile/paddle-mobile/Executor.swift +++ b/metal/paddle-mobile/paddle-mobile/Executor.swift @@ -15,17 +15,15 @@ import Foundation public class Executor { - var ops: [Runable] = [] - public init(program: Program) throws { - for block in program.programDesc.blocks { - for varDesc in block.vars { - if !varDesc.persistable { - program.scope.vars[varDesc.name] = Texture.init() - } - } + var ops: [Runable & InferShaperable] = [] + let program: Program + public init(inProgram: Program) throws { + program = inProgram + for block in inProgram.programDesc.blocks { for op in block.ops { do { - let op = try OpCreator

.shared.creat(opDesc: op, scope: program.scope) + let op = try OpCreator

.shared.creat(opDesc: op, scope: inProgram.scope) + op.inferShape() ops.append(op) } catch let error { throw error @@ -34,10 +32,16 @@ public class Executor { } } - public func predict() { + public func predict(input: Texture) throws -> Texture { + program.scope[program.feedKey] = input for op in ops { op.run() } + let outputVar = program.scope[program.fetchKey] + guard let output = outputVar as? Texture else { + throw PaddleMobileError.netError(message: "output var type error") + } + return output } } diff --git a/metal/paddle-mobile/paddle-mobile/Loader.swift b/metal/paddle-mobile/paddle-mobile/Loader.swift index f5275d37417c0a277aed04b720ba09c25e063685..acc6903b5db80bbae5b8bc1900899ca41ad309d0 100644 --- a/metal/paddle-mobile/paddle-mobile/Loader.swift +++ b/metal/paddle-mobile/paddle-mobile/Loader.swift @@ -156,14 +156,20 @@ public class Loader { let protoProgram = try PaddleMobile_Framework_Proto_ProgramDesc.init( serializedData: modelData) let scope = Scope.init() - let program = Program.init(protoProgramDesc: protoProgram, inParamPath: paraPath, inScope: scope) + let programDesc = ProgramDesc.init(protoProgram: protoProgram) guard let paraLoader = try? ParaLoader.init(paramPath: paraPath) else { throw PaddleMobileError.loaderError(message: "load para error") } - - for block in program.programDesc.blocks { + + guard programDesc.blocks.count > 0 else { + throw PaddleMobileError.loaderError(message: "count of blocks must greater than 0") + } + + for block in programDesc.blocks { for varDesc in block.vars { + print(varDesc.name + "\(varDesc.persistable)") + if (varDesc.type == .LodTensor) { if (varDesc.persistable && varDesc.type != .FeedMiniBatch @@ -192,11 +198,33 @@ public class Loader { } paraData.convert(to: .NHWC) let tensor = Tensor

.init(inData: paraData) - scope.vars[varDesc.name] = tensor + scope[varDesc.name] = tensor + } else { + scope[varDesc.name] = Texture.init() } + } else { + scope[varDesc.name] = Texture.init() } } } + + let block = programDesc.blocks[0] + guard let firstOp = block.ops.first, let lastOp = block.ops.last else { + throw PaddleMobileError.loaderError(message: "at least two operator") + } + guard firstOp.type == gFeedType, lastOp.type == gFetchType else { + throw PaddleMobileError.loaderError(message: "the first op is not feed or the last op is not fetch") + } + + guard let inputKey = opInfos[gFeedType]?.inputs.first, let outKey = opInfos[gFetchType]?.outputs.first else { + throw PaddleMobileError.loaderError(message: "the feed input key or fetch output key not found") + } + guard let feedKey = firstOp.inputs[inputKey]?.first, let fetchKey = lastOp.outputs[outKey]?.first else { + throw PaddleMobileError.loaderError(message: "feed key or fetch key not found") + } + + let program = Program.init(protoProgramDesc: protoProgram, inParamPath: paraPath, inScope: scope, inFeedKey: feedKey, inFetchKey: fetchKey) + return program } catch _ { throw PaddleMobileError.loaderError(message: "protobuf decoder error") diff --git a/metal/paddle-mobile/paddle-mobile/Operators/OpParam.swift b/metal/paddle-mobile/paddle-mobile/Operators/Base/OpParam.swift similarity index 100% rename from metal/paddle-mobile/paddle-mobile/Operators/OpParam.swift rename to metal/paddle-mobile/paddle-mobile/Operators/Base/OpParam.swift diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Operator.swift b/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift similarity index 95% rename from metal/paddle-mobile/paddle-mobile/Operators/Operator.swift rename to metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift index 05434c2b956423e32d539dfba417f245c3037ab6..7255c73773c35484fa392373ba5d45a654d06b45 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Operator.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift @@ -14,7 +14,6 @@ import Foundation - protocol Runable { func run() func runImpl() @@ -27,7 +26,7 @@ extension Runable where Self: OperatorProtocol{ } protocol Creator where Self: OperatorProtocol{ - associatedtype OpType: OperatorProtocol + associatedtype OpType: OperatorProtocol & Runable & InferShaperable static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType } @@ -41,6 +40,11 @@ extension Creator where Self: OperatorProtocol { } } +protocol InferShaperable { + func inferShape() +} + + protocol OperatorProtocol { associatedtype ParamType: OpParam var type: String { get } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift index 706e8002607112570beb07835ca1d28b7fc3b959..7a5733a70a854d972e477a0c5f57f3da517dcd22 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift @@ -18,8 +18,8 @@ struct BatchNormParam: OpParam { typealias ParamPrecisionType = P init(opDesc: OpDesc, scope: Scope) throws { do { - inputX = try BatchNormParam.inputX(inputs: opDesc.inputs, from: scope) - outputY = try BatchNormParam.outputY(outputs: opDesc.outputs, from: scope) + input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: scope) + output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: scope) inputBias = try BatchNormParam.inputBiase(inputs: opDesc.paraInputs, from: scope) inputMean = try BatchNormParam.inputMean(inputs: opDesc.paraInputs, from: scope) inputScale = try BatchNormParam.inputScale(inputs: opDesc.paraInputs, from: scope) @@ -31,8 +31,8 @@ struct BatchNormParam: OpParam { throw error } } - let inputX: Texture - let outputY: Texture + let input: Texture + let output: Texture let inputBias: Tensor let inputMean: Tensor let inputScale: Tensor @@ -42,7 +42,10 @@ struct BatchNormParam: OpParam { let is_test: Bool } -class BatchNormOp: Operator>, Runable, Creator{ +class BatchNormOp: Operator>, Runable, Creator, InferShaperable{ + func inferShape() { + para.output.dim = para.input.dim + } typealias OpType = BatchNormOp

func runImpl() { print("this is BatchNormOp") diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift index 9d3d0d8fed518c0954d452a3a162727c70a6efb8..69fa76fd3a62ac875347f6c19ada6a88a579f809 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift @@ -39,7 +39,28 @@ struct ConvParam: OpParam { let groups: Int } -class ConvOp: Operator>, Runable, Creator { +class ConvOp: Operator>, Runable, Creator, InferShaperable { + func inferShape() { +// let inDims = para.input.dim +// let filterDim = para.filter.dim +// let strides = para.stride +// let paddings = para.paddings +// let dilations = para.dilations +// +// var outDim = [inDims[0]] +// for i in 0.. func runImpl() { print("this is conv") diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift index 5ae2a6d388cfe6d810e6d38bbe904c96b8728c8d..8cae36254ec2bef687993ca086610cc1500be339 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift @@ -18,21 +18,26 @@ struct ElementwiseAddParam: OpParam { typealias ParamPrecisionType = P init(opDesc: OpDesc, scope: Scope) throws { do { - inputX = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: scope) + input = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: scope) inputY = try ElementwiseAddParam.inputY(inputs: opDesc.inputs, from: scope) - out = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: scope) + output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: scope) axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs) } catch let error { throw error } } - let inputX: Texture + let input: Texture let inputY: Tensor

- let out: Texture + let output: Texture let axis: Int } -class ElementwiseAddOp: Operator>, Runable, Creator{ +class ElementwiseAddOp: Operator>, Runable, Creator, InferShaperable{ + + func inferShape() { + para.output.dim = para.input.dim + } + typealias OpType = ElementwiseAddOp

func runImpl() { print("this is ElementwiseAddOp") diff --git a/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift index c82aa10d773ac2673c563b5f1a798ae7f8c67999..8c4bee3aa6ca6b6c80b4f2d25d6d0c7415718c86 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift @@ -15,15 +15,28 @@ import Foundation struct FeedParam: OpParam{ + var output: Texture + var input: Texture + init(opDesc: OpDesc, scope: Scope) throws { - + do { + input = try FeedParam.inputX(inputs: opDesc.inputs, from: scope) + output = try FeedParam.outputOut(outputs: opDesc.outputs, from: scope) + } catch let error { + throw error + } } typealias ParamPrecisionType = P } -class FeedOp: Operator>, Runable, Creator { +class FeedOp: Operator>, Runable, Creator, InferShaperable { typealias OpType = FeedOp

+ + func inferShape() { + para.output.dim = para.input.dim + } + func runImpl() { print("feed op") } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift index 5e17bfefe2abd083556e36419c7ede4be88c6c3f..bd333bff0ae39bfcbe57b0d55eb91bcd786f5dad 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift @@ -15,17 +15,29 @@ import Foundation struct FetchParam: OpParam{ + var input: Texture + var output: Texture + init(opDesc: OpDesc, scope: Scope) throws { - + do { + input = try FetchParam.inputX(inputs: opDesc.inputs, from: scope) + output = try FetchParam.outputOut(outputs: opDesc.outputs, from: scope) + } catch let error { + throw error + } } typealias ParamPrecisionType = P } -class FetchOp: Operator>, Runable, Creator { +class FetchOp: Operator>, Runable, Creator, InferShaperable{ + func inferShape() { + para.output.dim = para.input.dim + } + typealias OpType = FetchOp

func runImpl() { - print("feed op") + print("fetch op") } } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift index 59fff1c18633f8423c28aceb4f9bbff02aac4c4b..d641d13ecbcf1f4e70a0ff056eab5c68b4fc30ae 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift @@ -18,17 +18,22 @@ struct ReluParam: OpParam { typealias ParamPrecisionType = P init(opDesc: OpDesc, scope: Scope) throws { do { - inputX = try ReluParam.inputX(inputs: opDesc.inputs, from: scope) - out = try ReluParam.outputOut(outputs: opDesc.outputs, from: scope) + input = try ReluParam.inputX(inputs: opDesc.inputs, from: scope) + output = try ReluParam.outputOut(outputs: opDesc.outputs, from: scope) } catch let error { throw error } } - let inputX: Texture - let out: Texture + let input: Texture + let output: Texture } -class ReluOp: Operator>, Runable, Creator{ +class ReluOp: Operator>, Runable, Creator, InferShaperable{ + + func inferShape() { + para.output.dim = para.input.dim + } + typealias OpType = ReluOp

func runImpl() { print("this is ReluOp") diff --git a/metal/paddle-mobile/paddle-mobile/Program/Program.swift b/metal/paddle-mobile/paddle-mobile/Program/Program.swift index a346af8304688ba766928c6c6132dc7b840e683d..5a93e79338b5d694178313064cf49cfbb0d969db 100644 --- a/metal/paddle-mobile/paddle-mobile/Program/Program.swift +++ b/metal/paddle-mobile/paddle-mobile/Program/Program.swift @@ -16,11 +16,15 @@ import Foundation public struct Program { let paramPath: String + let feedKey: String + let fetchKey: String let programDesc: ProgramDesc let scope: Scope - init(protoProgramDesc: PaddleMobile_Framework_Proto_ProgramDesc, inParamPath: String, inScope: Scope) { + init(protoProgramDesc: PaddleMobile_Framework_Proto_ProgramDesc, inParamPath: String, inScope: Scope, inFeedKey: String, inFetchKey: String) { programDesc = ProgramDesc.init(protoProgram: protoProgramDesc) paramPath = inParamPath scope = inScope + feedKey = inFeedKey + fetchKey = inFetchKey } } diff --git a/metal/paddle-mobile/paddle-mobile/Program/Scope.swift b/metal/paddle-mobile/paddle-mobile/Program/Scope.swift index 0f34ed20ddf7feed4e22c8ce330a3343a41edc93..7ad95fa5357941c5c24163b37a7eaf6961e4be7e 100644 --- a/metal/paddle-mobile/paddle-mobile/Program/Scope.swift +++ b/metal/paddle-mobile/paddle-mobile/Program/Scope.swift @@ -17,7 +17,13 @@ import Foundation class Scope { var vars: [String : Variant] = [:] subscript(key: String) -> Variant?{ - return vars[key] + get { + return vars[key] + } + set { + vars[key] = newValue + } + } } diff --git a/metal/paddle-mobile/paddle-mobile/framework/Dim.swift b/metal/paddle-mobile/paddle-mobile/framework/Dim.swift index fa96f23b74832b0244535ddf28b221d5fdbb9bc7..eb5c57a5611eafbb4fc50742f21f22e5b7a57903 100644 --- a/metal/paddle-mobile/paddle-mobile/framework/Dim.swift +++ b/metal/paddle-mobile/paddle-mobile/framework/Dim.swift @@ -14,7 +14,7 @@ import Foundation -struct Dim { +public struct Dim { init(inDim: [Int]) { dims = inDim } diff --git a/metal/paddle-mobile/paddle-mobile/framework/Tensor.swift b/metal/paddle-mobile/paddle-mobile/framework/Tensor.swift index 7d00e9b8687674cf955910f3d208fb3140a7032f..8183f80523bf903373b4b99e1f4460ce76367477 100644 --- a/metal/paddle-mobile/paddle-mobile/framework/Tensor.swift +++ b/metal/paddle-mobile/paddle-mobile/framework/Tensor.swift @@ -14,10 +14,28 @@ import Foundation -class Tensor { +protocol Tensorial { + var dim: Dim { get set } + func numel() -> Int + func dataLayout() -> DataLayout +} + +extension Tensorial { + func numel() -> Int { + return dim.numel() + } +} + +class Tensor : Tensorial { var dim: Dim { - return paraData.dim + get { + return paraData.dim + } + set { + paraData.dim = newValue + } } + let paraData: ParamData

init(inDimArray: [Int], inData: ParamData

) { paraData = inData @@ -25,9 +43,6 @@ class Tensor { init(inData: ParamData

) { paraData = inData } - func numel() -> Int { - return dim.numel() - } func dataLayout() -> DataLayout { return paraData.layout diff --git a/metal/paddle-mobile/paddle-mobile/framework/Texture.swift b/metal/paddle-mobile/paddle-mobile/framework/Texture.swift index 5acd14fbb0a8e8a6bed24b73ab8315b11a87f817..e62fb1d8d7ed10d2cf7ca94f6a3df6496090f07f 100644 --- a/metal/paddle-mobile/paddle-mobile/framework/Texture.swift +++ b/metal/paddle-mobile/paddle-mobile/framework/Texture.swift @@ -12,8 +12,25 @@ See the License for the specific language governing permissions and limitations under the License. */ +import Metal import Foundation -class Texture { +public class Texture: Tensorial { + var dim: Dim +// let texture: MTLTexture + func dataLayout() -> DataLayout { + return .NHWC + } + + public init(inTexture: MTLTexture, inDim: Dim) { +// texture = inTexture + dim = inDim + } + + public init() { + dim = Dim.init(inDim: []) + +// fatalError() + } }