提交 888877ff 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #460 from codeWorm2015/metal

add infer shape of metal
......@@ -7,7 +7,7 @@
<key>paddle-mobile-demo.xcscheme</key>
<dict>
<key>orderHint</key>
<integer>4</integer>
<integer>3</integer>
</dict>
</dict>
</dict>
......
......@@ -24,8 +24,9 @@ class ViewController: UIViewController {
let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null"
let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null"
let program = try loader.load(modelPath: modelPath, paraPath: paraPath)
let executor = try Executor<Float>.init(program: program)
executor.predict()
let executor = try Executor<Float>.init(inProgram: program)
let output = try executor.predict(input: Texture.init())
print(output)
} catch let error {
print(error)
}
......
......@@ -155,12 +155,12 @@
FC039BA320E11CBC0081E9F8 /* Operators */ = {
isa = PBXGroup;
children = (
FCD592FA20E248EC00252966 /* Base */,
FCD592F920E248EC00252966 /* Kernels */,
FC039BA420E11CBC0081E9F8 /* ConvOp.swift */,
FC039BA520E11CBC0081E9F8 /* ElementwiseAddOp.swift */,
FC039BA620E11CBC0081E9F8 /* Operator.swift */,
FC039BA720E11CBC0081E9F8 /* BatchNormOp.swift */,
FC039BA820E11CBC0081E9F8 /* ReluOp.swift */,
FC9D037820E229E4000F735A /* OpParam.swift */,
FC9D037F20E22FBB000F735A /* FeedOp.swift */,
FC9D038120E2312E000F735A /* FetchOp.swift */,
);
......@@ -183,6 +183,22 @@
path = Program;
sourceTree = "<group>";
};
FCD592F920E248EC00252966 /* Kernels */ = {
isa = PBXGroup;
children = (
);
path = Kernels;
sourceTree = "<group>";
};
FCD592FA20E248EC00252966 /* Base */ = {
isa = PBXGroup;
children = (
FC9D037820E229E4000F735A /* OpParam.swift */,
FC039BA620E11CBC0081E9F8 /* Operator.swift */,
);
path = Base;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXHeadersBuildPhase section */
......
......@@ -7,7 +7,7 @@
<key>paddle-mobile.xcscheme</key>
<dict>
<key>orderHint</key>
<integer>3</integer>
<integer>4</integer>
</dict>
</dict>
</dict>
......
......@@ -60,7 +60,7 @@ class OpCreator<P: PrecisionType> {
}
}
func creat(opDesc: OpDesc, scope: Scope) throws -> Runable {
func creat(opDesc: OpDesc, scope: Scope) throws -> Runable & InferShaperable {
guard let opCreator = opCreators[opDesc.type] else {
throw PaddleMobileError.opError(message: "there is no " + opDesc.type + " yet")
}
......@@ -72,7 +72,7 @@ class OpCreator<P: PrecisionType> {
}
}
let opCreators: [String : (OpDesc, Scope) throws -> Runable] =
let opCreators: [String : (OpDesc, Scope) throws -> Runable & InferShaperable] =
[gConvType : ConvOp<P>.creat,
gBatchNormType : BatchNormOp<P>.creat,
gReluType : ReluOp<P>.creat,
......
......@@ -15,17 +15,15 @@
import Foundation
public class Executor<P: PrecisionType> {
var ops: [Runable] = []
public init(program: Program) throws {
for block in program.programDesc.blocks {
for varDesc in block.vars {
if !varDesc.persistable {
program.scope.vars[varDesc.name] = Texture.init()
}
}
var ops: [Runable & InferShaperable] = []
let program: Program
public init(inProgram: Program) throws {
program = inProgram
for block in inProgram.programDesc.blocks {
for op in block.ops {
do {
let op = try OpCreator<P>.shared.creat(opDesc: op, scope: program.scope)
let op = try OpCreator<P>.shared.creat(opDesc: op, scope: inProgram.scope)
op.inferShape()
ops.append(op)
} catch let error {
throw error
......@@ -34,10 +32,16 @@ public class Executor<P: PrecisionType> {
}
}
public func predict() {
public func predict(input: Texture) throws -> Texture {
program.scope[program.feedKey] = input
for op in ops {
op.run()
}
let outputVar = program.scope[program.fetchKey]
guard let output = outputVar as? Texture else {
throw PaddleMobileError.netError(message: "output var type error")
}
return output
}
}
......
......@@ -156,14 +156,20 @@ public class Loader<P: PrecisionType> {
let protoProgram = try PaddleMobile_Framework_Proto_ProgramDesc.init(
serializedData: modelData)
let scope = Scope.init()
let program = Program.init(protoProgramDesc: protoProgram, inParamPath: paraPath, inScope: scope)
let programDesc = ProgramDesc.init(protoProgram: protoProgram)
guard let paraLoader = try? ParaLoader.init(paramPath: paraPath) else {
throw PaddleMobileError.loaderError(message: "load para error")
}
for block in program.programDesc.blocks {
guard programDesc.blocks.count > 0 else {
throw PaddleMobileError.loaderError(message: "count of blocks must greater than 0")
}
for block in programDesc.blocks {
for varDesc in block.vars {
print(varDesc.name + "\(varDesc.persistable)")
if (varDesc.type == .LodTensor) {
if (varDesc.persistable
&& varDesc.type != .FeedMiniBatch
......@@ -192,11 +198,33 @@ public class Loader<P: PrecisionType> {
}
paraData.convert(to: .NHWC)
let tensor = Tensor<P>.init(inData: paraData)
scope.vars[varDesc.name] = tensor
scope[varDesc.name] = tensor
} else {
scope[varDesc.name] = Texture.init()
}
} else {
scope[varDesc.name] = Texture.init()
}
}
}
let block = programDesc.blocks[0]
guard let firstOp = block.ops.first, let lastOp = block.ops.last else {
throw PaddleMobileError.loaderError(message: "at least two operator")
}
guard firstOp.type == gFeedType, lastOp.type == gFetchType else {
throw PaddleMobileError.loaderError(message: "the first op is not feed or the last op is not fetch")
}
guard let inputKey = opInfos[gFeedType]?.inputs.first, let outKey = opInfos[gFetchType]?.outputs.first else {
throw PaddleMobileError.loaderError(message: "the feed input key or fetch output key not found")
}
guard let feedKey = firstOp.inputs[inputKey]?.first, let fetchKey = lastOp.outputs[outKey]?.first else {
throw PaddleMobileError.loaderError(message: "feed key or fetch key not found")
}
let program = Program.init(protoProgramDesc: protoProgram, inParamPath: paraPath, inScope: scope, inFeedKey: feedKey, inFetchKey: fetchKey)
return program
} catch _ {
throw PaddleMobileError.loaderError(message: "protobuf decoder error")
......
......@@ -14,7 +14,6 @@
import Foundation
protocol Runable {
func run()
func runImpl()
......@@ -27,7 +26,7 @@ extension Runable where Self: OperatorProtocol{
}
protocol Creator where Self: OperatorProtocol{
associatedtype OpType: OperatorProtocol
associatedtype OpType: OperatorProtocol & Runable & InferShaperable
static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType
}
......@@ -41,6 +40,11 @@ extension Creator where Self: OperatorProtocol {
}
}
protocol InferShaperable {
func inferShape()
}
protocol OperatorProtocol {
associatedtype ParamType: OpParam
var type: String { get }
......
......@@ -18,8 +18,8 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws {
do {
inputX = try BatchNormParam.inputX(inputs: opDesc.inputs, from: scope)
outputY = try BatchNormParam.outputY(outputs: opDesc.outputs, from: scope)
input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: scope)
output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: scope)
inputBias = try BatchNormParam.inputBiase(inputs: opDesc.paraInputs, from: scope)
inputMean = try BatchNormParam.inputMean(inputs: opDesc.paraInputs, from: scope)
inputScale = try BatchNormParam.inputScale(inputs: opDesc.paraInputs, from: scope)
......@@ -31,8 +31,8 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
throw error
}
}
let inputX: Texture
let outputY: Texture
let input: Texture
let output: Texture
let inputBias: Tensor<ParamPrecisionType>
let inputMean: Tensor<ParamPrecisionType>
let inputScale: Tensor<ParamPrecisionType>
......@@ -42,7 +42,10 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
let is_test: Bool
}
class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>>, Runable, Creator{
class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>>, Runable, Creator, InferShaperable{
func inferShape() {
para.output.dim = para.input.dim
}
typealias OpType = BatchNormOp<P>
func runImpl() {
print("this is BatchNormOp")
......
......@@ -39,7 +39,28 @@ struct ConvParam<P: PrecisionType>: OpParam {
let groups: Int
}
class ConvOp<P: PrecisionType>: Operator<ConvParam<P>>, Runable, Creator {
class ConvOp<P: PrecisionType>: Operator<ConvParam<P>>, Runable, Creator, InferShaperable {
func inferShape() {
// let inDims = para.input.dim
// let filterDim = para.filter.dim
// let strides = para.stride
// let paddings = para.paddings
// let dilations = para.dilations
//
// var outDim = [inDims[0]]
// for i in 0..<strides.count {
// let dilation: Int = Int(dilations[i])
// let filterSize: Int = filterDim[i + 1]
// let inputSize: Int = inDims[i + 1]
// let padding: Int = Int(paddings[i])
// let stride: Int = Int(strides[i])
// let dKernel = dilation * (filterSize - 1) + 1
// let outputSize = (inputSize + 2 * padding - dKernel) / stride + 1
// outDim.append(outputSize)
// }
// outDim.append(filterDim[0])
}
typealias OpType = ConvOp<P>
func runImpl() {
print("this is conv")
......
......@@ -18,21 +18,26 @@ struct ElementwiseAddParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws {
do {
inputX = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: scope)
input = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: scope)
inputY = try ElementwiseAddParam.inputY(inputs: opDesc.inputs, from: scope)
out = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: scope)
output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: scope)
axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs)
} catch let error {
throw error
}
}
let inputX: Texture
let input: Texture
let inputY: Tensor<P>
let out: Texture
let output: Texture
let axis: Int
}
class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>>, Runable, Creator{
class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>>, Runable, Creator, InferShaperable{
func inferShape() {
para.output.dim = para.input.dim
}
typealias OpType = ElementwiseAddOp<P>
func runImpl() {
print("this is ElementwiseAddOp")
......
......@@ -15,15 +15,28 @@
import Foundation
struct FeedParam<P: PrecisionType>: OpParam{
init(opDesc: OpDesc, scope: Scope) throws {
var output: Texture
var input: Texture
init(opDesc: OpDesc, scope: Scope) throws {
do {
input = try FeedParam.inputX(inputs: opDesc.inputs, from: scope)
output = try FeedParam.outputOut(outputs: opDesc.outputs, from: scope)
} catch let error {
throw error
}
}
typealias ParamPrecisionType = P
}
class FeedOp<P: PrecisionType>: Operator<FeedParam<P>>, Runable, Creator {
class FeedOp<P: PrecisionType>: Operator<FeedParam<P>>, Runable, Creator, InferShaperable {
typealias OpType = FeedOp<P>
func inferShape() {
para.output.dim = para.input.dim
}
func runImpl() {
print("feed op")
}
......
......@@ -15,17 +15,29 @@
import Foundation
struct FetchParam<P: PrecisionType>: OpParam{
init(opDesc: OpDesc, scope: Scope) throws {
var input: Texture
var output: Texture
init(opDesc: OpDesc, scope: Scope) throws {
do {
input = try FetchParam.inputX(inputs: opDesc.inputs, from: scope)
output = try FetchParam.outputOut(outputs: opDesc.outputs, from: scope)
} catch let error {
throw error
}
}
typealias ParamPrecisionType = P
}
class FetchOp<P: PrecisionType>: Operator<FetchParam<P>>, Runable, Creator {
class FetchOp<P: PrecisionType>: Operator<FetchParam<P>>, Runable, Creator, InferShaperable{
func inferShape() {
para.output.dim = para.input.dim
}
typealias OpType = FetchOp<P>
func runImpl() {
print("feed op")
print("fetch op")
}
}
......@@ -18,17 +18,22 @@ struct ReluParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws {
do {
inputX = try ReluParam.inputX(inputs: opDesc.inputs, from: scope)
out = try ReluParam.outputOut(outputs: opDesc.outputs, from: scope)
input = try ReluParam.inputX(inputs: opDesc.inputs, from: scope)
output = try ReluParam.outputOut(outputs: opDesc.outputs, from: scope)
} catch let error {
throw error
}
}
let inputX: Texture
let out: Texture
let input: Texture
let output: Texture
}
class ReluOp<P: PrecisionType>: Operator<ReluParam<P>>, Runable, Creator{
class ReluOp<P: PrecisionType>: Operator<ReluParam<P>>, Runable, Creator, InferShaperable{
func inferShape() {
para.output.dim = para.input.dim
}
typealias OpType = ReluOp<P>
func runImpl() {
print("this is ReluOp")
......
......@@ -16,11 +16,15 @@ import Foundation
public struct Program {
let paramPath: String
let feedKey: String
let fetchKey: String
let programDesc: ProgramDesc
let scope: Scope
init(protoProgramDesc: PaddleMobile_Framework_Proto_ProgramDesc, inParamPath: String, inScope: Scope) {
init(protoProgramDesc: PaddleMobile_Framework_Proto_ProgramDesc, inParamPath: String, inScope: Scope, inFeedKey: String, inFetchKey: String) {
programDesc = ProgramDesc.init(protoProgram: protoProgramDesc)
paramPath = inParamPath
scope = inScope
feedKey = inFeedKey
fetchKey = inFetchKey
}
}
......@@ -17,7 +17,13 @@ import Foundation
class Scope {
var vars: [String : Variant] = [:]
subscript(key: String) -> Variant?{
get {
return vars[key]
}
set {
vars[key] = newValue
}
}
}
......@@ -14,7 +14,7 @@
import Foundation
struct Dim {
public struct Dim {
init(inDim: [Int]) {
dims = inDim
}
......
......@@ -14,10 +14,28 @@
import Foundation
class Tensor <P: PrecisionType>{
protocol Tensorial {
var dim: Dim { get set }
func numel() -> Int
func dataLayout() -> DataLayout
}
extension Tensorial {
func numel() -> Int {
return dim.numel()
}
}
class Tensor <P: PrecisionType>: Tensorial {
var dim: Dim {
get {
return paraData.dim
}
set {
paraData.dim = newValue
}
}
let paraData: ParamData<P>
init(inDimArray: [Int], inData: ParamData<P>) {
paraData = inData
......@@ -25,9 +43,6 @@ class Tensor <P: PrecisionType>{
init(inData: ParamData<P>) {
paraData = inData
}
func numel() -> Int {
return dim.numel()
}
func dataLayout() -> DataLayout {
return paraData.layout
......
......@@ -12,8 +12,25 @@
See the License for the specific language governing permissions and
limitations under the License. */
import Metal
import Foundation
class Texture {
public class Texture: Tensorial {
var dim: Dim
// let texture: MTLTexture
func dataLayout() -> DataLayout {
return .NHWC
}
public init(inTexture: MTLTexture, inDim: Dim) {
// texture = inTexture
dim = inDim
}
public init() {
dim = Dim.init(inDim: [])
// fatalError()
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册