From ab9d9c5d797d047538113aeb147ba3770482594c Mon Sep 17 00:00:00 2001 From: liuruilong Date: Wed, 27 Jun 2018 12:21:37 +0800 Subject: [PATCH] add infer shape of metal --- .../xcschemes/xcschememanagement.plist | 2 +- .../paddle-mobile-demo/ViewController.swift | 5 +-- .../paddle-mobile.xcodeproj/project.pbxproj | 20 +++++++++-- .../xcschemes/xcschememanagement.plist | 2 +- .../paddle-mobile/Common/Types.swift | 4 +-- .../paddle-mobile/Executor.swift | 24 +++++++------ .../paddle-mobile/paddle-mobile/Loader.swift | 36 ++++++++++++++++--- .../Operators/{ => Base}/OpParam.swift | 0 .../Operators/{ => Base}/Operator.swift | 8 +++-- .../paddle-mobile/Operators/BatchNormOp.swift | 13 ++++--- .../paddle-mobile/Operators/ConvOp.swift | 23 +++++++++++- .../Operators/ElementwiseAddOp.swift | 15 +++++--- .../paddle-mobile/Operators/FeedOp.swift | 17 +++++++-- .../paddle-mobile/Operators/FetchOp.swift | 18 ++++++++-- .../paddle-mobile/Operators/ReluOp.swift | 15 +++++--- .../paddle-mobile/Program/Program.swift | 6 +++- .../paddle-mobile/Program/Scope.swift | 8 ++++- .../paddle-mobile/framework/Dim.swift | 2 +- .../paddle-mobile/framework/Tensor.swift | 25 ++++++++++--- .../paddle-mobile/framework/Texture.swift | 19 +++++++++- 20 files changed, 208 insertions(+), 54 deletions(-) rename metal/paddle-mobile/paddle-mobile/Operators/{ => Base}/OpParam.swift (100%) rename metal/paddle-mobile/paddle-mobile/Operators/{ => Base}/Operator.swift (95%) diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist b/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist index 125fd5ec74..e2c6b20680 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist +++ b/metal/paddle-mobile-demo/paddle-mobile-demo.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist @@ -7,7 +7,7 @@ paddle-mobile-demo.xcscheme orderHint - 4 + 3 diff --git a/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift b/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift index 875f61726c..2a8ee13525 100644 --- a/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift +++ b/metal/paddle-mobile-demo/paddle-mobile-demo/ViewController.swift @@ -24,8 +24,9 @@ class ViewController: UIViewController { let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null" let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null" let program = try loader.load(modelPath: modelPath, paraPath: paraPath) - let executor = try Executor.init(program: program) - executor.predict() + let executor = try Executor.init(inProgram: program) + let output = try executor.predict(input: Texture.init()) + print(output) } catch let error { print(error) } diff --git a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj index 4c3e982aa8..27017d3f89 100644 --- a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj @@ -155,12 +155,12 @@ FC039BA320E11CBC0081E9F8 /* Operators */ = { isa = PBXGroup; children = ( + FCD592FA20E248EC00252966 /* Base */, + FCD592F920E248EC00252966 /* Kernels */, FC039BA420E11CBC0081E9F8 /* ConvOp.swift */, FC039BA520E11CBC0081E9F8 /* ElementwiseAddOp.swift */, - FC039BA620E11CBC0081E9F8 /* Operator.swift */, FC039BA720E11CBC0081E9F8 /* BatchNormOp.swift */, FC039BA820E11CBC0081E9F8 /* ReluOp.swift */, - FC9D037820E229E4000F735A /* OpParam.swift */, FC9D037F20E22FBB000F735A /* FeedOp.swift */, FC9D038120E2312E000F735A /* FetchOp.swift */, ); @@ -183,6 +183,22 @@ path = Program; sourceTree = ""; }; + FCD592F920E248EC00252966 /* Kernels */ = { + isa = PBXGroup; + children = ( + ); + path = Kernels; + sourceTree = ""; + }; + FCD592FA20E248EC00252966 /* Base */ = { + isa = PBXGroup; + children = ( + FC9D037820E229E4000F735A /* OpParam.swift */, + FC039BA620E11CBC0081E9F8 /* Operator.swift */, + ); + path = Base; + sourceTree = ""; + }; /* End PBXGroup section */ /* Begin PBXHeadersBuildPhase section */ diff --git a/metal/paddle-mobile/paddle-mobile.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist b/metal/paddle-mobile/paddle-mobile.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist index 5272521360..50f16e4d7c 100644 --- a/metal/paddle-mobile/paddle-mobile.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist +++ b/metal/paddle-mobile/paddle-mobile.xcodeproj/xcuserdata/liuruilong.xcuserdatad/xcschemes/xcschememanagement.plist @@ -7,7 +7,7 @@ paddle-mobile.xcscheme orderHint - 3 + 4 diff --git a/metal/paddle-mobile/paddle-mobile/Common/Types.swift b/metal/paddle-mobile/paddle-mobile/Common/Types.swift index 38a7a9d70e..d2dd95b895 100644 --- a/metal/paddle-mobile/paddle-mobile/Common/Types.swift +++ b/metal/paddle-mobile/paddle-mobile/Common/Types.swift @@ -60,7 +60,7 @@ class OpCreator { } } - func creat(opDesc: OpDesc, scope: Scope) throws -> Runable { + func creat(opDesc: OpDesc, scope: Scope) throws -> Runable & InferShaperable { guard let opCreator = opCreators[opDesc.type] else { throw PaddleMobileError.opError(message: "there is no " + opDesc.type + " yet") } @@ -72,7 +72,7 @@ class OpCreator { } } - let opCreators: [String : (OpDesc, Scope) throws -> Runable] = + let opCreators: [String : (OpDesc, Scope) throws -> Runable & InferShaperable] = [gConvType : ConvOp

.creat, gBatchNormType : BatchNormOp

.creat, gReluType : ReluOp

.creat, diff --git a/metal/paddle-mobile/paddle-mobile/Executor.swift b/metal/paddle-mobile/paddle-mobile/Executor.swift index 7274e7c0d8..92370bd724 100644 --- a/metal/paddle-mobile/paddle-mobile/Executor.swift +++ b/metal/paddle-mobile/paddle-mobile/Executor.swift @@ -15,17 +15,15 @@ import Foundation public class Executor { - var ops: [Runable] = [] - public init(program: Program) throws { - for block in program.programDesc.blocks { - for varDesc in block.vars { - if !varDesc.persistable { - program.scope.vars[varDesc.name] = Texture.init() - } - } + var ops: [Runable & InferShaperable] = [] + let program: Program + public init(inProgram: Program) throws { + program = inProgram + for block in inProgram.programDesc.blocks { for op in block.ops { do { - let op = try OpCreator

.shared.creat(opDesc: op, scope: program.scope) + let op = try OpCreator

.shared.creat(opDesc: op, scope: inProgram.scope) + op.inferShape() ops.append(op) } catch let error { throw error @@ -34,10 +32,16 @@ public class Executor { } } - public func predict() { + public func predict(input: Texture) throws -> Texture { + program.scope[program.feedKey] = input for op in ops { op.run() } + let outputVar = program.scope[program.fetchKey] + guard let output = outputVar as? Texture else { + throw PaddleMobileError.netError(message: "output var type error") + } + return output } } diff --git a/metal/paddle-mobile/paddle-mobile/Loader.swift b/metal/paddle-mobile/paddle-mobile/Loader.swift index f5275d3741..acc6903b5d 100644 --- a/metal/paddle-mobile/paddle-mobile/Loader.swift +++ b/metal/paddle-mobile/paddle-mobile/Loader.swift @@ -156,14 +156,20 @@ public class Loader { let protoProgram = try PaddleMobile_Framework_Proto_ProgramDesc.init( serializedData: modelData) let scope = Scope.init() - let program = Program.init(protoProgramDesc: protoProgram, inParamPath: paraPath, inScope: scope) + let programDesc = ProgramDesc.init(protoProgram: protoProgram) guard let paraLoader = try? ParaLoader.init(paramPath: paraPath) else { throw PaddleMobileError.loaderError(message: "load para error") } - - for block in program.programDesc.blocks { + + guard programDesc.blocks.count > 0 else { + throw PaddleMobileError.loaderError(message: "count of blocks must greater than 0") + } + + for block in programDesc.blocks { for varDesc in block.vars { + print(varDesc.name + "\(varDesc.persistable)") + if (varDesc.type == .LodTensor) { if (varDesc.persistable && varDesc.type != .FeedMiniBatch @@ -192,11 +198,33 @@ public class Loader { } paraData.convert(to: .NHWC) let tensor = Tensor

.init(inData: paraData) - scope.vars[varDesc.name] = tensor + scope[varDesc.name] = tensor + } else { + scope[varDesc.name] = Texture.init() } + } else { + scope[varDesc.name] = Texture.init() } } } + + let block = programDesc.blocks[0] + guard let firstOp = block.ops.first, let lastOp = block.ops.last else { + throw PaddleMobileError.loaderError(message: "at least two operator") + } + guard firstOp.type == gFeedType, lastOp.type == gFetchType else { + throw PaddleMobileError.loaderError(message: "the first op is not feed or the last op is not fetch") + } + + guard let inputKey = opInfos[gFeedType]?.inputs.first, let outKey = opInfos[gFetchType]?.outputs.first else { + throw PaddleMobileError.loaderError(message: "the feed input key or fetch output key not found") + } + guard let feedKey = firstOp.inputs[inputKey]?.first, let fetchKey = lastOp.outputs[outKey]?.first else { + throw PaddleMobileError.loaderError(message: "feed key or fetch key not found") + } + + let program = Program.init(protoProgramDesc: protoProgram, inParamPath: paraPath, inScope: scope, inFeedKey: feedKey, inFetchKey: fetchKey) + return program } catch _ { throw PaddleMobileError.loaderError(message: "protobuf decoder error") diff --git a/metal/paddle-mobile/paddle-mobile/Operators/OpParam.swift b/metal/paddle-mobile/paddle-mobile/Operators/Base/OpParam.swift similarity index 100% rename from metal/paddle-mobile/paddle-mobile/Operators/OpParam.swift rename to metal/paddle-mobile/paddle-mobile/Operators/Base/OpParam.swift diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Operator.swift b/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift similarity index 95% rename from metal/paddle-mobile/paddle-mobile/Operators/Operator.swift rename to metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift index 05434c2b95..7255c73773 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Operator.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift @@ -14,7 +14,6 @@ import Foundation - protocol Runable { func run() func runImpl() @@ -27,7 +26,7 @@ extension Runable where Self: OperatorProtocol{ } protocol Creator where Self: OperatorProtocol{ - associatedtype OpType: OperatorProtocol + associatedtype OpType: OperatorProtocol & Runable & InferShaperable static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType } @@ -41,6 +40,11 @@ extension Creator where Self: OperatorProtocol { } } +protocol InferShaperable { + func inferShape() +} + + protocol OperatorProtocol { associatedtype ParamType: OpParam var type: String { get } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift index 706e800260..7a5733a70a 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift @@ -18,8 +18,8 @@ struct BatchNormParam: OpParam { typealias ParamPrecisionType = P init(opDesc: OpDesc, scope: Scope) throws { do { - inputX = try BatchNormParam.inputX(inputs: opDesc.inputs, from: scope) - outputY = try BatchNormParam.outputY(outputs: opDesc.outputs, from: scope) + input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: scope) + output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: scope) inputBias = try BatchNormParam.inputBiase(inputs: opDesc.paraInputs, from: scope) inputMean = try BatchNormParam.inputMean(inputs: opDesc.paraInputs, from: scope) inputScale = try BatchNormParam.inputScale(inputs: opDesc.paraInputs, from: scope) @@ -31,8 +31,8 @@ struct BatchNormParam: OpParam { throw error } } - let inputX: Texture - let outputY: Texture + let input: Texture + let output: Texture let inputBias: Tensor let inputMean: Tensor let inputScale: Tensor @@ -42,7 +42,10 @@ struct BatchNormParam: OpParam { let is_test: Bool } -class BatchNormOp: Operator>, Runable, Creator{ +class BatchNormOp: Operator>, Runable, Creator, InferShaperable{ + func inferShape() { + para.output.dim = para.input.dim + } typealias OpType = BatchNormOp

func runImpl() { print("this is BatchNormOp") diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift index 9d3d0d8fed..69fa76fd3a 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift @@ -39,7 +39,28 @@ struct ConvParam: OpParam { let groups: Int } -class ConvOp: Operator>, Runable, Creator { +class ConvOp: Operator>, Runable, Creator, InferShaperable { + func inferShape() { +// let inDims = para.input.dim +// let filterDim = para.filter.dim +// let strides = para.stride +// let paddings = para.paddings +// let dilations = para.dilations +// +// var outDim = [inDims[0]] +// for i in 0.. func runImpl() { print("this is conv") diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift index 5ae2a6d388..8cae36254e 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ElementwiseAddOp.swift @@ -18,21 +18,26 @@ struct ElementwiseAddParam: OpParam { typealias ParamPrecisionType = P init(opDesc: OpDesc, scope: Scope) throws { do { - inputX = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: scope) + input = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: scope) inputY = try ElementwiseAddParam.inputY(inputs: opDesc.inputs, from: scope) - out = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: scope) + output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: scope) axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs) } catch let error { throw error } } - let inputX: Texture + let input: Texture let inputY: Tensor

- let out: Texture + let output: Texture let axis: Int } -class ElementwiseAddOp: Operator>, Runable, Creator{ +class ElementwiseAddOp: Operator>, Runable, Creator, InferShaperable{ + + func inferShape() { + para.output.dim = para.input.dim + } + typealias OpType = ElementwiseAddOp

func runImpl() { print("this is ElementwiseAddOp") diff --git a/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift index c82aa10d77..8c4bee3aa6 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift @@ -15,15 +15,28 @@ import Foundation struct FeedParam: OpParam{ + var output: Texture + var input: Texture + init(opDesc: OpDesc, scope: Scope) throws { - + do { + input = try FeedParam.inputX(inputs: opDesc.inputs, from: scope) + output = try FeedParam.outputOut(outputs: opDesc.outputs, from: scope) + } catch let error { + throw error + } } typealias ParamPrecisionType = P } -class FeedOp: Operator>, Runable, Creator { +class FeedOp: Operator>, Runable, Creator, InferShaperable { typealias OpType = FeedOp

+ + func inferShape() { + para.output.dim = para.input.dim + } + func runImpl() { print("feed op") } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift index 5e17bfefe2..bd333bff0a 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift @@ -15,17 +15,29 @@ import Foundation struct FetchParam: OpParam{ + var input: Texture + var output: Texture + init(opDesc: OpDesc, scope: Scope) throws { - + do { + input = try FetchParam.inputX(inputs: opDesc.inputs, from: scope) + output = try FetchParam.outputOut(outputs: opDesc.outputs, from: scope) + } catch let error { + throw error + } } typealias ParamPrecisionType = P } -class FetchOp: Operator>, Runable, Creator { +class FetchOp: Operator>, Runable, Creator, InferShaperable{ + func inferShape() { + para.output.dim = para.input.dim + } + typealias OpType = FetchOp

func runImpl() { - print("feed op") + print("fetch op") } } diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift index 59fff1c186..d641d13ecb 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift @@ -18,17 +18,22 @@ struct ReluParam: OpParam { typealias ParamPrecisionType = P init(opDesc: OpDesc, scope: Scope) throws { do { - inputX = try ReluParam.inputX(inputs: opDesc.inputs, from: scope) - out = try ReluParam.outputOut(outputs: opDesc.outputs, from: scope) + input = try ReluParam.inputX(inputs: opDesc.inputs, from: scope) + output = try ReluParam.outputOut(outputs: opDesc.outputs, from: scope) } catch let error { throw error } } - let inputX: Texture - let out: Texture + let input: Texture + let output: Texture } -class ReluOp: Operator>, Runable, Creator{ +class ReluOp: Operator>, Runable, Creator, InferShaperable{ + + func inferShape() { + para.output.dim = para.input.dim + } + typealias OpType = ReluOp

func runImpl() { print("this is ReluOp") diff --git a/metal/paddle-mobile/paddle-mobile/Program/Program.swift b/metal/paddle-mobile/paddle-mobile/Program/Program.swift index a346af8304..5a93e79338 100644 --- a/metal/paddle-mobile/paddle-mobile/Program/Program.swift +++ b/metal/paddle-mobile/paddle-mobile/Program/Program.swift @@ -16,11 +16,15 @@ import Foundation public struct Program { let paramPath: String + let feedKey: String + let fetchKey: String let programDesc: ProgramDesc let scope: Scope - init(protoProgramDesc: PaddleMobile_Framework_Proto_ProgramDesc, inParamPath: String, inScope: Scope) { + init(protoProgramDesc: PaddleMobile_Framework_Proto_ProgramDesc, inParamPath: String, inScope: Scope, inFeedKey: String, inFetchKey: String) { programDesc = ProgramDesc.init(protoProgram: protoProgramDesc) paramPath = inParamPath scope = inScope + feedKey = inFeedKey + fetchKey = inFetchKey } } diff --git a/metal/paddle-mobile/paddle-mobile/Program/Scope.swift b/metal/paddle-mobile/paddle-mobile/Program/Scope.swift index 0f34ed20dd..7ad95fa535 100644 --- a/metal/paddle-mobile/paddle-mobile/Program/Scope.swift +++ b/metal/paddle-mobile/paddle-mobile/Program/Scope.swift @@ -17,7 +17,13 @@ import Foundation class Scope { var vars: [String : Variant] = [:] subscript(key: String) -> Variant?{ - return vars[key] + get { + return vars[key] + } + set { + vars[key] = newValue + } + } } diff --git a/metal/paddle-mobile/paddle-mobile/framework/Dim.swift b/metal/paddle-mobile/paddle-mobile/framework/Dim.swift index fa96f23b74..eb5c57a561 100644 --- a/metal/paddle-mobile/paddle-mobile/framework/Dim.swift +++ b/metal/paddle-mobile/paddle-mobile/framework/Dim.swift @@ -14,7 +14,7 @@ import Foundation -struct Dim { +public struct Dim { init(inDim: [Int]) { dims = inDim } diff --git a/metal/paddle-mobile/paddle-mobile/framework/Tensor.swift b/metal/paddle-mobile/paddle-mobile/framework/Tensor.swift index 7d00e9b868..8183f80523 100644 --- a/metal/paddle-mobile/paddle-mobile/framework/Tensor.swift +++ b/metal/paddle-mobile/paddle-mobile/framework/Tensor.swift @@ -14,10 +14,28 @@ import Foundation -class Tensor { +protocol Tensorial { + var dim: Dim { get set } + func numel() -> Int + func dataLayout() -> DataLayout +} + +extension Tensorial { + func numel() -> Int { + return dim.numel() + } +} + +class Tensor : Tensorial { var dim: Dim { - return paraData.dim + get { + return paraData.dim + } + set { + paraData.dim = newValue + } } + let paraData: ParamData

init(inDimArray: [Int], inData: ParamData

) { paraData = inData @@ -25,9 +43,6 @@ class Tensor { init(inData: ParamData

) { paraData = inData } - func numel() -> Int { - return dim.numel() - } func dataLayout() -> DataLayout { return paraData.layout diff --git a/metal/paddle-mobile/paddle-mobile/framework/Texture.swift b/metal/paddle-mobile/paddle-mobile/framework/Texture.swift index 5acd14fbb0..e62fb1d8d7 100644 --- a/metal/paddle-mobile/paddle-mobile/framework/Texture.swift +++ b/metal/paddle-mobile/paddle-mobile/framework/Texture.swift @@ -12,8 +12,25 @@ See the License for the specific language governing permissions and limitations under the License. */ +import Metal import Foundation -class Texture { +public class Texture: Tensorial { + var dim: Dim +// let texture: MTLTexture + func dataLayout() -> DataLayout { + return .NHWC + } + + public init(inTexture: MTLTexture, inDim: Dim) { +// texture = inTexture + dim = inDim + } + + public init() { + dim = Dim.init(inDim: []) + +// fatalError() + } } -- GitLab