提交 888877ff 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #460 from codeWorm2015/metal

add infer shape of metal
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
<key>paddle-mobile-demo.xcscheme</key> <key>paddle-mobile-demo.xcscheme</key>
<dict> <dict>
<key>orderHint</key> <key>orderHint</key>
<integer>4</integer> <integer>3</integer>
</dict> </dict>
</dict> </dict>
</dict> </dict>
......
...@@ -24,8 +24,9 @@ class ViewController: UIViewController { ...@@ -24,8 +24,9 @@ class ViewController: UIViewController {
let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null" let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null"
let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null" let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null"
let program = try loader.load(modelPath: modelPath, paraPath: paraPath) let program = try loader.load(modelPath: modelPath, paraPath: paraPath)
let executor = try Executor<Float>.init(program: program) let executor = try Executor<Float>.init(inProgram: program)
executor.predict() let output = try executor.predict(input: Texture.init())
print(output)
} catch let error { } catch let error {
print(error) print(error)
} }
......
...@@ -155,12 +155,12 @@ ...@@ -155,12 +155,12 @@
FC039BA320E11CBC0081E9F8 /* Operators */ = { FC039BA320E11CBC0081E9F8 /* Operators */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FCD592FA20E248EC00252966 /* Base */,
FCD592F920E248EC00252966 /* Kernels */,
FC039BA420E11CBC0081E9F8 /* ConvOp.swift */, FC039BA420E11CBC0081E9F8 /* ConvOp.swift */,
FC039BA520E11CBC0081E9F8 /* ElementwiseAddOp.swift */, FC039BA520E11CBC0081E9F8 /* ElementwiseAddOp.swift */,
FC039BA620E11CBC0081E9F8 /* Operator.swift */,
FC039BA720E11CBC0081E9F8 /* BatchNormOp.swift */, FC039BA720E11CBC0081E9F8 /* BatchNormOp.swift */,
FC039BA820E11CBC0081E9F8 /* ReluOp.swift */, FC039BA820E11CBC0081E9F8 /* ReluOp.swift */,
FC9D037820E229E4000F735A /* OpParam.swift */,
FC9D037F20E22FBB000F735A /* FeedOp.swift */, FC9D037F20E22FBB000F735A /* FeedOp.swift */,
FC9D038120E2312E000F735A /* FetchOp.swift */, FC9D038120E2312E000F735A /* FetchOp.swift */,
); );
...@@ -183,6 +183,22 @@ ...@@ -183,6 +183,22 @@
path = Program; path = Program;
sourceTree = "<group>"; sourceTree = "<group>";
}; };
FCD592F920E248EC00252966 /* Kernels */ = {
isa = PBXGroup;
children = (
);
path = Kernels;
sourceTree = "<group>";
};
FCD592FA20E248EC00252966 /* Base */ = {
isa = PBXGroup;
children = (
FC9D037820E229E4000F735A /* OpParam.swift */,
FC039BA620E11CBC0081E9F8 /* Operator.swift */,
);
path = Base;
sourceTree = "<group>";
};
/* End PBXGroup section */ /* End PBXGroup section */
/* Begin PBXHeadersBuildPhase section */ /* Begin PBXHeadersBuildPhase section */
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
<key>paddle-mobile.xcscheme</key> <key>paddle-mobile.xcscheme</key>
<dict> <dict>
<key>orderHint</key> <key>orderHint</key>
<integer>3</integer> <integer>4</integer>
</dict> </dict>
</dict> </dict>
</dict> </dict>
......
...@@ -60,7 +60,7 @@ class OpCreator<P: PrecisionType> { ...@@ -60,7 +60,7 @@ class OpCreator<P: PrecisionType> {
} }
} }
func creat(opDesc: OpDesc, scope: Scope) throws -> Runable { func creat(opDesc: OpDesc, scope: Scope) throws -> Runable & InferShaperable {
guard let opCreator = opCreators[opDesc.type] else { guard let opCreator = opCreators[opDesc.type] else {
throw PaddleMobileError.opError(message: "there is no " + opDesc.type + " yet") throw PaddleMobileError.opError(message: "there is no " + opDesc.type + " yet")
} }
...@@ -72,7 +72,7 @@ class OpCreator<P: PrecisionType> { ...@@ -72,7 +72,7 @@ class OpCreator<P: PrecisionType> {
} }
} }
let opCreators: [String : (OpDesc, Scope) throws -> Runable] = let opCreators: [String : (OpDesc, Scope) throws -> Runable & InferShaperable] =
[gConvType : ConvOp<P>.creat, [gConvType : ConvOp<P>.creat,
gBatchNormType : BatchNormOp<P>.creat, gBatchNormType : BatchNormOp<P>.creat,
gReluType : ReluOp<P>.creat, gReluType : ReluOp<P>.creat,
......
...@@ -15,17 +15,15 @@ ...@@ -15,17 +15,15 @@
import Foundation import Foundation
public class Executor<P: PrecisionType> { public class Executor<P: PrecisionType> {
var ops: [Runable] = [] var ops: [Runable & InferShaperable] = []
public init(program: Program) throws { let program: Program
for block in program.programDesc.blocks { public init(inProgram: Program) throws {
for varDesc in block.vars { program = inProgram
if !varDesc.persistable { for block in inProgram.programDesc.blocks {
program.scope.vars[varDesc.name] = Texture.init()
}
}
for op in block.ops { for op in block.ops {
do { do {
let op = try OpCreator<P>.shared.creat(opDesc: op, scope: program.scope) let op = try OpCreator<P>.shared.creat(opDesc: op, scope: inProgram.scope)
op.inferShape()
ops.append(op) ops.append(op)
} catch let error { } catch let error {
throw error throw error
...@@ -34,10 +32,16 @@ public class Executor<P: PrecisionType> { ...@@ -34,10 +32,16 @@ public class Executor<P: PrecisionType> {
} }
} }
public func predict() { public func predict(input: Texture) throws -> Texture {
program.scope[program.feedKey] = input
for op in ops { for op in ops {
op.run() op.run()
} }
let outputVar = program.scope[program.fetchKey]
guard let output = outputVar as? Texture else {
throw PaddleMobileError.netError(message: "output var type error")
}
return output
} }
} }
......
...@@ -156,14 +156,20 @@ public class Loader<P: PrecisionType> { ...@@ -156,14 +156,20 @@ public class Loader<P: PrecisionType> {
let protoProgram = try PaddleMobile_Framework_Proto_ProgramDesc.init( let protoProgram = try PaddleMobile_Framework_Proto_ProgramDesc.init(
serializedData: modelData) serializedData: modelData)
let scope = Scope.init() let scope = Scope.init()
let program = Program.init(protoProgramDesc: protoProgram, inParamPath: paraPath, inScope: scope) let programDesc = ProgramDesc.init(protoProgram: protoProgram)
guard let paraLoader = try? ParaLoader.init(paramPath: paraPath) else { guard let paraLoader = try? ParaLoader.init(paramPath: paraPath) else {
throw PaddleMobileError.loaderError(message: "load para error") throw PaddleMobileError.loaderError(message: "load para error")
} }
for block in program.programDesc.blocks { guard programDesc.blocks.count > 0 else {
throw PaddleMobileError.loaderError(message: "count of blocks must greater than 0")
}
for block in programDesc.blocks {
for varDesc in block.vars { for varDesc in block.vars {
print(varDesc.name + "\(varDesc.persistable)")
if (varDesc.type == .LodTensor) { if (varDesc.type == .LodTensor) {
if (varDesc.persistable if (varDesc.persistable
&& varDesc.type != .FeedMiniBatch && varDesc.type != .FeedMiniBatch
...@@ -192,11 +198,33 @@ public class Loader<P: PrecisionType> { ...@@ -192,11 +198,33 @@ public class Loader<P: PrecisionType> {
} }
paraData.convert(to: .NHWC) paraData.convert(to: .NHWC)
let tensor = Tensor<P>.init(inData: paraData) let tensor = Tensor<P>.init(inData: paraData)
scope.vars[varDesc.name] = tensor scope[varDesc.name] = tensor
} else {
scope[varDesc.name] = Texture.init()
} }
} else {
scope[varDesc.name] = Texture.init()
} }
} }
} }
let block = programDesc.blocks[0]
guard let firstOp = block.ops.first, let lastOp = block.ops.last else {
throw PaddleMobileError.loaderError(message: "at least two operator")
}
guard firstOp.type == gFeedType, lastOp.type == gFetchType else {
throw PaddleMobileError.loaderError(message: "the first op is not feed or the last op is not fetch")
}
guard let inputKey = opInfos[gFeedType]?.inputs.first, let outKey = opInfos[gFetchType]?.outputs.first else {
throw PaddleMobileError.loaderError(message: "the feed input key or fetch output key not found")
}
guard let feedKey = firstOp.inputs[inputKey]?.first, let fetchKey = lastOp.outputs[outKey]?.first else {
throw PaddleMobileError.loaderError(message: "feed key or fetch key not found")
}
let program = Program.init(protoProgramDesc: protoProgram, inParamPath: paraPath, inScope: scope, inFeedKey: feedKey, inFetchKey: fetchKey)
return program return program
} catch _ { } catch _ {
throw PaddleMobileError.loaderError(message: "protobuf decoder error") throw PaddleMobileError.loaderError(message: "protobuf decoder error")
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
import Foundation import Foundation
protocol Runable { protocol Runable {
func run() func run()
func runImpl() func runImpl()
...@@ -27,7 +26,7 @@ extension Runable where Self: OperatorProtocol{ ...@@ -27,7 +26,7 @@ extension Runable where Self: OperatorProtocol{
} }
protocol Creator where Self: OperatorProtocol{ protocol Creator where Self: OperatorProtocol{
associatedtype OpType: OperatorProtocol associatedtype OpType: OperatorProtocol & Runable & InferShaperable
static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType
} }
...@@ -41,6 +40,11 @@ extension Creator where Self: OperatorProtocol { ...@@ -41,6 +40,11 @@ extension Creator where Self: OperatorProtocol {
} }
} }
protocol InferShaperable {
func inferShape()
}
protocol OperatorProtocol { protocol OperatorProtocol {
associatedtype ParamType: OpParam associatedtype ParamType: OpParam
var type: String { get } var type: String { get }
......
...@@ -18,8 +18,8 @@ struct BatchNormParam<P: PrecisionType>: OpParam { ...@@ -18,8 +18,8 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws { init(opDesc: OpDesc, scope: Scope) throws {
do { do {
inputX = try BatchNormParam.inputX(inputs: opDesc.inputs, from: scope) input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: scope)
outputY = try BatchNormParam.outputY(outputs: opDesc.outputs, from: scope) output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: scope)
inputBias = try BatchNormParam.inputBiase(inputs: opDesc.paraInputs, from: scope) inputBias = try BatchNormParam.inputBiase(inputs: opDesc.paraInputs, from: scope)
inputMean = try BatchNormParam.inputMean(inputs: opDesc.paraInputs, from: scope) inputMean = try BatchNormParam.inputMean(inputs: opDesc.paraInputs, from: scope)
inputScale = try BatchNormParam.inputScale(inputs: opDesc.paraInputs, from: scope) inputScale = try BatchNormParam.inputScale(inputs: opDesc.paraInputs, from: scope)
...@@ -31,8 +31,8 @@ struct BatchNormParam<P: PrecisionType>: OpParam { ...@@ -31,8 +31,8 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
throw error throw error
} }
} }
let inputX: Texture let input: Texture
let outputY: Texture let output: Texture
let inputBias: Tensor<ParamPrecisionType> let inputBias: Tensor<ParamPrecisionType>
let inputMean: Tensor<ParamPrecisionType> let inputMean: Tensor<ParamPrecisionType>
let inputScale: Tensor<ParamPrecisionType> let inputScale: Tensor<ParamPrecisionType>
...@@ -42,7 +42,10 @@ struct BatchNormParam<P: PrecisionType>: OpParam { ...@@ -42,7 +42,10 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
let is_test: Bool let is_test: Bool
} }
class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>>, Runable, Creator{ class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>>, Runable, Creator, InferShaperable{
func inferShape() {
para.output.dim = para.input.dim
}
typealias OpType = BatchNormOp<P> typealias OpType = BatchNormOp<P>
func runImpl() { func runImpl() {
print("this is BatchNormOp") print("this is BatchNormOp")
......
...@@ -39,7 +39,28 @@ struct ConvParam<P: PrecisionType>: OpParam { ...@@ -39,7 +39,28 @@ struct ConvParam<P: PrecisionType>: OpParam {
let groups: Int let groups: Int
} }
class ConvOp<P: PrecisionType>: Operator<ConvParam<P>>, Runable, Creator { class ConvOp<P: PrecisionType>: Operator<ConvParam<P>>, Runable, Creator, InferShaperable {
func inferShape() {
// let inDims = para.input.dim
// let filterDim = para.filter.dim
// let strides = para.stride
// let paddings = para.paddings
// let dilations = para.dilations
//
// var outDim = [inDims[0]]
// for i in 0..<strides.count {
// let dilation: Int = Int(dilations[i])
// let filterSize: Int = filterDim[i + 1]
// let inputSize: Int = inDims[i + 1]
// let padding: Int = Int(paddings[i])
// let stride: Int = Int(strides[i])
// let dKernel = dilation * (filterSize - 1) + 1
// let outputSize = (inputSize + 2 * padding - dKernel) / stride + 1
// outDim.append(outputSize)
// }
// outDim.append(filterDim[0])
}
typealias OpType = ConvOp<P> typealias OpType = ConvOp<P>
func runImpl() { func runImpl() {
print("this is conv") print("this is conv")
......
...@@ -18,21 +18,26 @@ struct ElementwiseAddParam<P: PrecisionType>: OpParam { ...@@ -18,21 +18,26 @@ struct ElementwiseAddParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws { init(opDesc: OpDesc, scope: Scope) throws {
do { do {
inputX = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: scope) input = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: scope)
inputY = try ElementwiseAddParam.inputY(inputs: opDesc.inputs, from: scope) inputY = try ElementwiseAddParam.inputY(inputs: opDesc.inputs, from: scope)
out = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: scope) output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: scope)
axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs) axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs)
} catch let error { } catch let error {
throw error throw error
} }
} }
let inputX: Texture let input: Texture
let inputY: Tensor<P> let inputY: Tensor<P>
let out: Texture let output: Texture
let axis: Int let axis: Int
} }
class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>>, Runable, Creator{ class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>>, Runable, Creator, InferShaperable{
func inferShape() {
para.output.dim = para.input.dim
}
typealias OpType = ElementwiseAddOp<P> typealias OpType = ElementwiseAddOp<P>
func runImpl() { func runImpl() {
print("this is ElementwiseAddOp") print("this is ElementwiseAddOp")
......
...@@ -15,15 +15,28 @@ ...@@ -15,15 +15,28 @@
import Foundation import Foundation
struct FeedParam<P: PrecisionType>: OpParam{ struct FeedParam<P: PrecisionType>: OpParam{
var output: Texture
var input: Texture
init(opDesc: OpDesc, scope: Scope) throws { init(opDesc: OpDesc, scope: Scope) throws {
do {
input = try FeedParam.inputX(inputs: opDesc.inputs, from: scope)
output = try FeedParam.outputOut(outputs: opDesc.outputs, from: scope)
} catch let error {
throw error
}
} }
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
} }
class FeedOp<P: PrecisionType>: Operator<FeedParam<P>>, Runable, Creator { class FeedOp<P: PrecisionType>: Operator<FeedParam<P>>, Runable, Creator, InferShaperable {
typealias OpType = FeedOp<P> typealias OpType = FeedOp<P>
func inferShape() {
para.output.dim = para.input.dim
}
func runImpl() { func runImpl() {
print("feed op") print("feed op")
} }
......
...@@ -15,17 +15,29 @@ ...@@ -15,17 +15,29 @@
import Foundation import Foundation
struct FetchParam<P: PrecisionType>: OpParam{ struct FetchParam<P: PrecisionType>: OpParam{
var input: Texture
var output: Texture
init(opDesc: OpDesc, scope: Scope) throws { init(opDesc: OpDesc, scope: Scope) throws {
do {
input = try FetchParam.inputX(inputs: opDesc.inputs, from: scope)
output = try FetchParam.outputOut(outputs: opDesc.outputs, from: scope)
} catch let error {
throw error
}
} }
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
} }
class FetchOp<P: PrecisionType>: Operator<FetchParam<P>>, Runable, Creator { class FetchOp<P: PrecisionType>: Operator<FetchParam<P>>, Runable, Creator, InferShaperable{
func inferShape() {
para.output.dim = para.input.dim
}
typealias OpType = FetchOp<P> typealias OpType = FetchOp<P>
func runImpl() { func runImpl() {
print("feed op") print("fetch op")
} }
} }
...@@ -18,17 +18,22 @@ struct ReluParam<P: PrecisionType>: OpParam { ...@@ -18,17 +18,22 @@ struct ReluParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws { init(opDesc: OpDesc, scope: Scope) throws {
do { do {
inputX = try ReluParam.inputX(inputs: opDesc.inputs, from: scope) input = try ReluParam.inputX(inputs: opDesc.inputs, from: scope)
out = try ReluParam.outputOut(outputs: opDesc.outputs, from: scope) output = try ReluParam.outputOut(outputs: opDesc.outputs, from: scope)
} catch let error { } catch let error {
throw error throw error
} }
} }
let inputX: Texture let input: Texture
let out: Texture let output: Texture
} }
class ReluOp<P: PrecisionType>: Operator<ReluParam<P>>, Runable, Creator{ class ReluOp<P: PrecisionType>: Operator<ReluParam<P>>, Runable, Creator, InferShaperable{
func inferShape() {
para.output.dim = para.input.dim
}
typealias OpType = ReluOp<P> typealias OpType = ReluOp<P>
func runImpl() { func runImpl() {
print("this is ReluOp") print("this is ReluOp")
......
...@@ -16,11 +16,15 @@ import Foundation ...@@ -16,11 +16,15 @@ import Foundation
public struct Program { public struct Program {
let paramPath: String let paramPath: String
let feedKey: String
let fetchKey: String
let programDesc: ProgramDesc let programDesc: ProgramDesc
let scope: Scope let scope: Scope
init(protoProgramDesc: PaddleMobile_Framework_Proto_ProgramDesc, inParamPath: String, inScope: Scope) { init(protoProgramDesc: PaddleMobile_Framework_Proto_ProgramDesc, inParamPath: String, inScope: Scope, inFeedKey: String, inFetchKey: String) {
programDesc = ProgramDesc.init(protoProgram: protoProgramDesc) programDesc = ProgramDesc.init(protoProgram: protoProgramDesc)
paramPath = inParamPath paramPath = inParamPath
scope = inScope scope = inScope
feedKey = inFeedKey
fetchKey = inFetchKey
} }
} }
...@@ -17,7 +17,13 @@ import Foundation ...@@ -17,7 +17,13 @@ import Foundation
class Scope { class Scope {
var vars: [String : Variant] = [:] var vars: [String : Variant] = [:]
subscript(key: String) -> Variant?{ subscript(key: String) -> Variant?{
return vars[key] get {
return vars[key]
}
set {
vars[key] = newValue
}
} }
} }
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import Foundation import Foundation
struct Dim { public struct Dim {
init(inDim: [Int]) { init(inDim: [Int]) {
dims = inDim dims = inDim
} }
......
...@@ -14,10 +14,28 @@ ...@@ -14,10 +14,28 @@
import Foundation import Foundation
class Tensor <P: PrecisionType>{ protocol Tensorial {
var dim: Dim { get set }
func numel() -> Int
func dataLayout() -> DataLayout
}
extension Tensorial {
func numel() -> Int {
return dim.numel()
}
}
class Tensor <P: PrecisionType>: Tensorial {
var dim: Dim { var dim: Dim {
return paraData.dim get {
return paraData.dim
}
set {
paraData.dim = newValue
}
} }
let paraData: ParamData<P> let paraData: ParamData<P>
init(inDimArray: [Int], inData: ParamData<P>) { init(inDimArray: [Int], inData: ParamData<P>) {
paraData = inData paraData = inData
...@@ -25,9 +43,6 @@ class Tensor <P: PrecisionType>{ ...@@ -25,9 +43,6 @@ class Tensor <P: PrecisionType>{
init(inData: ParamData<P>) { init(inData: ParamData<P>) {
paraData = inData paraData = inData
} }
func numel() -> Int {
return dim.numel()
}
func dataLayout() -> DataLayout { func dataLayout() -> DataLayout {
return paraData.layout return paraData.layout
......
...@@ -12,8 +12,25 @@ ...@@ -12,8 +12,25 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
import Metal
import Foundation import Foundation
class Texture { public class Texture: Tensorial {
var dim: Dim
// let texture: MTLTexture
func dataLayout() -> DataLayout {
return .NHWC
}
public init(inTexture: MTLTexture, inDim: Dim) {
// texture = inTexture
dim = inDim
}
public init() {
dim = Dim.init(inDim: [])
// fatalError()
}
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册