未验证 提交 5dfaa551 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #455 from codeWorm2015/metal

fix #454 metal can run ops
......@@ -70,15 +70,13 @@ build
cmake-build-debug
cmake-build-release
# Pods
# metal
Podfile.lock
SwiftProtobuf.framework
metal/Pods/
SwiftProtobuf.framework
paddle-mobile.xcworkspace
metal/models/
......
......@@ -13,9 +13,26 @@
FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8520E11C550081E9F8 /* Main.storyboard */; };
FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8820E11C560081E9F8 /* Assets.xcassets */; };
FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8A20E11C560081E9F8 /* LaunchScreen.storyboard */; };
FC039BC220E11CD00081E9F8 /* test.pb.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BC120E11CD00081E9F8 /* test.pb.swift */; };
FC9D037D20E22E4E000F735A /* params in Resources */ = {isa = PBXBuildFile; fileRef = FC9D037B20E22E4E000F735A /* params */; };
FC9D037E20E22E4E000F735A /* model in Resources */ = {isa = PBXBuildFile; fileRef = FC9D037C20E22E4E000F735A /* model */; };
FCEBEC2C20E1391F00C0B14D /* paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; };
FCEBEC2D20E1391F00C0B14D /* paddle_mobile.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
/* End PBXBuildFile section */
/* Begin PBXCopyFilesBuildPhase section */
FCEBEC2E20E1392000C0B14D /* Embed Frameworks */ = {
isa = PBXCopyFilesBuildPhase;
buildActionMask = 2147483647;
dstPath = "";
dstSubfolderSpec = 10;
files = (
FCEBEC2D20E1391F00C0B14D /* paddle_mobile.framework in Embed Frameworks */,
);
name = "Embed Frameworks";
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXCopyFilesBuildPhase section */
/* Begin PBXFileReference section */
081C9CF10DB06C58B8B6B039 /* Pods-paddle-mobile-demo.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-demo.release.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-demo/Pods-paddle-mobile-demo.release.xcconfig"; sourceTree = "<group>"; };
18896810981724F8A0FED62A /* Pods_paddle_mobile_demo.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile_demo.framework; sourceTree = BUILT_PRODUCTS_DIR; };
......@@ -27,7 +44,9 @@
FC039B8820E11C560081E9F8 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
FC039B8B20E11C560081E9F8 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/LaunchScreen.storyboard; sourceTree = "<group>"; };
FC039B8D20E11C560081E9F8 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
FC039BC120E11CD00081E9F8 /* test.pb.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = test.pb.swift; sourceTree = "<group>"; };
FC9D037B20E22E4E000F735A /* params */ = {isa = PBXFileReference; lastKnownFileType = file; path = params; sourceTree = "<group>"; };
FC9D037C20E22E4E000F735A /* model */ = {isa = PBXFileReference; lastKnownFileType = file; path = model; sourceTree = "<group>"; };
FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
......@@ -35,6 +54,7 @@
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
FCEBEC2C20E1391F00C0B14D /* paddle_mobile.framework in Frameworks */,
30D0ED21F392CFA3885B1002 /* Pods_paddle_mobile_demo.framework in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
......@@ -62,6 +82,7 @@
FC039B7520E11C550081E9F8 = {
isa = PBXGroup;
children = (
FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */,
FC039B8020E11C550081E9F8 /* paddle-mobile-demo */,
FC039B7F20E11C550081E9F8 /* Products */,
5722B50FEC38F55CA9B6A57B /* Pods */,
......@@ -80,7 +101,7 @@
FC039B8020E11C550081E9F8 /* paddle-mobile-demo */ = {
isa = PBXGroup;
children = (
FC039BC120E11CD00081E9F8 /* test.pb.swift */,
FC9D037A20E22E4E000F735A /* yolo */,
FC039B8120E11C550081E9F8 /* AppDelegate.swift */,
FC039B8320E11C550081E9F8 /* ViewController.swift */,
FC039B8520E11C550081E9F8 /* Main.storyboard */,
......@@ -91,6 +112,16 @@
path = "paddle-mobile-demo";
sourceTree = "<group>";
};
FC9D037A20E22E4E000F735A /* yolo */ = {
isa = PBXGroup;
children = (
FC9D037B20E22E4E000F735A /* params */,
FC9D037C20E22E4E000F735A /* model */,
);
name = yolo;
path = ../../models/yolo;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXNativeTarget section */
......@@ -103,6 +134,7 @@
FC039B7B20E11C550081E9F8 /* Frameworks */,
FC039B7C20E11C550081E9F8 /* Resources */,
84ED590C0E51ABA9C34F51B5 /* [CP] Embed Pods Frameworks */,
FCEBEC2E20E1392000C0B14D /* Embed Frameworks */,
);
buildRules = (
);
......@@ -152,7 +184,9 @@
buildActionMask = 2147483647;
files = (
FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */,
FC9D037E20E22E4E000F735A /* model in Resources */,
FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */,
FC9D037D20E22E4E000F735A /* params in Resources */,
FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;
......@@ -204,7 +238,6 @@
buildActionMask = 2147483647;
files = (
FC039B8420E11C550081E9F8 /* ViewController.swift in Sources */,
FC039BC220E11CD00081E9F8 /* test.pb.swift in Sources */,
FC039B8220E11C550081E9F8 /* AppDelegate.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
......
//
// AppDelegate.swift
// paddle-mobile-demo
//
// Created by liuRuiLong on 2018/6/25.
// Copyright © 2018年 orange. All rights reserved.
//
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import UIKit
......
//
// ViewController.swift
// paddle-mobile-demo
//
// Created by liuRuiLong on 2018/6/25.
// Copyright © 2018年 orange. All rights reserved.
//
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import UIKit
import paddle_mobile
......@@ -14,14 +19,17 @@ class ViewController: UIViewController {
override func viewDidLoad() {
super.viewDidLoad()
let loader = Loader<Float>.init()
do {
let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null"
let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null"
let program = try loader.load(modelPath: modelPath, paraPath: paraPath)
let executor = try Executor<Float>.init(program: program)
executor.predict()
} catch let error {
print(error)
}
override func didReceiveMemoryWarning() {
super.didReceiveMemoryWarning()
// Dispose of any resources that can be recreated.
}
}
// DO NOT EDIT.
//
// Generated by the Swift generator plugin for the protocol buffer compiler.
// Source: test.proto
//
// For information on using the generated types, please see the documenation:
// https://github.com/apple/swift-protobuf/
import Foundation
import SwiftProtobuf
// If the compiler emits an error on this type, it is because this file
// was generated by a version of the `protoc` Swift plug-in that is
// incompatible with the version of SwiftProtobuf to which you are linking.
// Please ensure that your are building against the same version of the API
// that was used to generate this file.
fileprivate struct _GeneratedWithProtocGenSwiftVersion: SwiftProtobuf.ProtobufAPIVersionCheck {
struct _2: SwiftProtobuf.ProtobufAPIVersion_2 {}
typealias Version = _2
}
struct BookInfo {
// SwiftProtobuf.Message conformance is added in an extension below. See the
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
// methods supported on all messages.
var id: Int64 = 0
var title: String = String()
var author: String = String()
var unknownFields = SwiftProtobuf.UnknownStorage()
init() {}
}
// MARK: - Code below here is support for the SwiftProtobuf runtime.
extension BookInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
static let protoMessageName: String = "BookInfo"
static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .same(proto: "id"),
2: .same(proto: "title"),
3: .same(proto: "author"),
]
mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
while let fieldNumber = try decoder.nextFieldNumber() {
switch fieldNumber {
case 1: try decoder.decodeSingularInt64Field(value: &self.id)
case 2: try decoder.decodeSingularStringField(value: &self.title)
case 3: try decoder.decodeSingularStringField(value: &self.author)
default: break
}
}
}
func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
if self.id != 0 {
try visitor.visitSingularInt64Field(value: self.id, fieldNumber: 1)
}
if !self.title.isEmpty {
try visitor.visitSingularStringField(value: self.title, fieldNumber: 2)
}
if !self.author.isEmpty {
try visitor.visitSingularStringField(value: self.author, fieldNumber: 3)
}
try unknownFields.traverse(visitor: &visitor)
}
func _protobuf_generated_isEqualTo(other: BookInfo) -> Bool {
if self.id != other.id {return false}
if self.title != other.title {return false}
if self.author != other.author {return false}
if unknownFields != other.unknownFields {return false}
return true
}
}
......@@ -30,6 +30,10 @@
FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB520E11CC20081E9F8 /* OpDesc.swift */; };
FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB620E11CC20081E9F8 /* Attribute.swift */; };
FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB720E11CC20081E9F8 /* BlockDesc.swift */; };
FC9D037920E229E4000F735A /* OpParam.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037820E229E4000F735A /* OpParam.swift */; };
FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037F20E22FBB000F735A /* FeedOp.swift */; };
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038120E2312E000F735A /* FetchOp.swift */; };
FC9D038420E23B01000F735A /* Texture.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038320E23B01000F735A /* Texture.swift */; };
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
......@@ -60,6 +64,10 @@
FC039BB520E11CC20081E9F8 /* OpDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OpDesc.swift; sourceTree = "<group>"; };
FC039BB620E11CC20081E9F8 /* Attribute.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Attribute.swift; sourceTree = "<group>"; };
FC039BB720E11CC20081E9F8 /* BlockDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BlockDesc.swift; sourceTree = "<group>"; };
FC9D037820E229E4000F735A /* OpParam.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpParam.swift; sourceTree = "<group>"; };
FC9D037F20E22FBB000F735A /* FeedOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FeedOp.swift; sourceTree = "<group>"; };
FC9D038120E2312E000F735A /* FetchOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FetchOp.swift; sourceTree = "<group>"; };
FC9D038320E23B01000F735A /* Texture.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
......@@ -115,8 +123,8 @@
FC039BAE20E11CC20081E9F8 /* Program */,
FC039BA320E11CBC0081E9F8 /* Operators */,
FC039BA120E11CB70081E9F8 /* Loader.swift */,
FC039B9C20E11CB20081E9F8 /* framework */,
FC039B9A20E11CA00081E9F8 /* Executor.swift */,
FC039B9C20E11CB20081E9F8 /* framework */,
FC039B9320E11C9A0081E9F8 /* Common */,
FC039B6D20E11C3C0081E9F8 /* paddle_mobile.h */,
FC039B6E20E11C3C0081E9F8 /* Info.plist */,
......@@ -139,6 +147,7 @@
children = (
FC039B9D20E11CB20081E9F8 /* Tensor.swift */,
FC039B9E20E11CB20081E9F8 /* Dim.swift */,
FC9D038320E23B01000F735A /* Texture.swift */,
);
path = framework;
sourceTree = "<group>";
......@@ -151,6 +160,9 @@
FC039BA620E11CBC0081E9F8 /* Operator.swift */,
FC039BA720E11CBC0081E9F8 /* BatchNormOp.swift */,
FC039BA820E11CBC0081E9F8 /* ReluOp.swift */,
FC9D037820E229E4000F735A /* OpParam.swift */,
FC9D037F20E22FBB000F735A /* FeedOp.swift */,
FC9D038120E2312E000F735A /* FetchOp.swift */,
);
path = Operators;
sourceTree = "<group>";
......@@ -271,21 +283,25 @@
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */,
FC039B9F20E11CB20081E9F8 /* Tensor.swift in Sources */,
FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */,
FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */,
FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */,
FC9D037920E229E4000F735A /* OpParam.swift in Sources */,
FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */,
FC039BA020E11CB20081E9F8 /* Dim.swift in Sources */,
FC039BB820E11CC20081E9F8 /* framework.pb.swift in Sources */,
FC039B9920E11C9A0081E9F8 /* Types.swift in Sources */,
FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */,
FC9D038420E23B01000F735A /* Texture.swift in Sources */,
FC039B9820E11C9A0081E9F8 /* Errors.swift in Sources */,
FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */,
FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */,
FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */,
FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */,
FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */,
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */,
FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */,
FC039BA220E11CB70081E9F8 /* Loader.swift in Sources */,
FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */,
......
......@@ -19,4 +19,5 @@ public enum PaddleMobileError: Error{
case netError(message: String)
case memoryError(message: String)
case paramError(message: String)
case opError(message: String)
}
......@@ -20,7 +20,7 @@ precedencegroup ExecutedOrFatalError{
higherThan: AssignmentPrecedence
}
infix operator ?!: ExecutedOrFatalError
func ?!<T>(option: T?, excuteOrError: @autoclosure () -> String) -> T{
public func ?!<T>(option: T?, excuteOrError: @autoclosure () -> String) -> T{
if let inOpt = option {
return inOpt
}else{
......
......@@ -35,16 +35,58 @@ protocol Variant {
extension Tensor: Variant {
}
extension Texture: Variant {
}
let gFetchType = "fetch"
let gFeedType = "feed"
let gConvType = "conv2d"
let gBatchNormType = "batch_norm"
let gReluType = "relu"
let gElementwiseAdd = "elementwise_add"
let opInputsOutputsKey = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
fileprivate var singletons : [String : Any] = [:]
class OpCreator<P: PrecisionType> {
static var shared : OpCreator<P> {
let key = String(describing: P.self)
if let singleton = singletons[key] {
return singleton as! OpCreator<P>
} else {
let newSingleton = OpCreator<P>()
singletons[key] = newSingleton
return newSingleton
}
}
func creat(opDesc: OpDesc, scope: Scope) throws -> Runable {
guard let opCreator = opCreators[opDesc.type] else {
throw PaddleMobileError.opError(message: "there is no " + opDesc.type + " yet")
}
do {
return try opCreator(opDesc, scope)
} catch let error {
throw error
}
}
let opCreators: [String : (OpDesc, Scope) throws -> Runable] =
[gConvType : ConvOp<P>.creat,
gBatchNormType : BatchNormOp<P>.creat,
gReluType : ReluOp<P>.creat,
gElementwiseAdd : ElementwiseAddOp<P>.creat,
gFeedType : FeedOp<P>.creat,
gFetchType : FetchOp<P>.creat]
private init(){}
}
let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
gBatchNormType : (inputs: ["X"], outputs: ["Y"]),
gReluType : (inputs: ["X"], outputs: ["Out"]),
gElementwiseAdd : (inputs: ["X", "Y"], outputs: ["Out"])]
gElementwiseAdd : (inputs: ["X", "Y"], outputs: ["Out"]),
gFeedType : (inputs: ["X"], outputs: ["Out"]),
gFetchType : (inputs: ["X"], outputs: ["Out"])]
......@@ -14,6 +14,31 @@
import Foundation
class Executor {
public class Executor<P: PrecisionType> {
var ops: [Runable] = []
public init(program: Program) throws {
for block in program.programDesc.blocks {
for varDesc in block.vars {
if !varDesc.persistable {
program.scope.vars[varDesc.name] = Texture.init()
}
}
for op in block.ops {
do {
let op = try OpCreator<P>.shared.creat(opDesc: op, scope: program.scope)
ops.append(op)
} catch let error {
throw error
}
}
}
}
public func predict() {
for op in ops {
op.run()
}
}
}
//public let paddle_executor: Executor = Executor.init()
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License. */
import Foundation
struct BatchNormParam<P: PrecisionType>: Param {
typealias ParamP = P
struct BatchNormParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws {
do {
inputX = try BatchNormParam.inputX(inputs: opDesc.inputs, from: scope)
......@@ -31,17 +31,25 @@ struct BatchNormParam<P: PrecisionType>: Param {
throw error
}
}
let inputX: Tensor<ParamP>
let outputY: Tensor<ParamP>
let inputBias: Tensor<ParamP>
let inputMean: Tensor<ParamP>
let inputScale: Tensor<ParamP>
let inputVariance: Tensor<ParamP>
let inputX: Texture
let outputY: Texture
let inputBias: Tensor<ParamPrecisionType>
let inputMean: Tensor<ParamPrecisionType>
let inputScale: Tensor<ParamPrecisionType>
let inputVariance: Tensor<ParamPrecisionType>
let epsilon: Float
let momentum: Float
let is_test: Bool
}
class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>> {
class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>>, Runable, Creator{
typealias OpType = BatchNormOp<P>
func runImpl() {
print("this is BatchNormOp")
}
}
......@@ -14,14 +14,14 @@
import Foundation
struct ConvParam<P: PrecisionType>: Param {
typealias ParamP = P
struct ConvParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws {
do {
filter = try ConvParam.inputFilter(paraInputs: opDesc.paraInputs, from: scope)
input = try ConvParam.input(inputs: opDesc.inputs, from: scope)
output = try ConvParam.output(outputs: opDesc.outputs, from: scope)
stride = try ConvParam.getAttr(key: "stride", attrs: opDesc.attrs)
stride = try ConvParam.getAttr(key: "strides", attrs: opDesc.attrs)
paddings = try ConvParam.getAttr(key: "paddings", attrs: opDesc.attrs)
dilations = try ConvParam.getAttr(key: "dilations", attrs: opDesc.attrs)
groups = try ConvParam.getAttr(key: "groups", attrs: opDesc.attrs)
......@@ -30,17 +30,18 @@ struct ConvParam<P: PrecisionType>: Param {
}
}
let input: Tensor<ParamP>
let output: Tensor<ParamP>
let filter: Tensor<ParamP>
let stride: [Int]
let paddings: [Int]
let dilations: [Int]
let input: Texture
let output: Texture
let filter: Tensor<ParamPrecisionType>
let stride: [Int32]
let paddings: [Int32]
let dilations: [Int32]
let groups: Int
}
class ConvOp<P: PrecisionType>: Operator<ConvParam<P>> {
override func runImpl() {
class ConvOp<P: PrecisionType>: Operator<ConvParam<P>>, Runable, Creator {
typealias OpType = ConvOp<P>
func runImpl() {
print("this is conv")
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License. */
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
struct ElementwiseAddParam<P: PrecisionType>: Param {
typealias ParamP = P
struct ElementwiseAddParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws {
do {
inputX = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: scope)
......@@ -25,14 +26,21 @@ struct ElementwiseAddParam<P: PrecisionType>: Param {
throw error
}
}
let inputX: Tensor<P>
let inputX: Texture
let inputY: Tensor<P>
let out: Tensor<P>
let out: Texture
let axis: Int
}
class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>>{
override func runImpl() {
class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>>, Runable, Creator{
typealias OpType = ElementwiseAddOp<P>
func runImpl() {
print("this is ElementwiseAddOp")
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
struct FeedParam<P: PrecisionType>: OpParam{
init(opDesc: OpDesc, scope: Scope) throws {
}
typealias ParamPrecisionType = P
}
class FeedOp<P: PrecisionType>: Operator<FeedParam<P>>, Runable, Creator {
typealias OpType = FeedOp<P>
func runImpl() {
print("feed op")
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
struct FetchParam<P: PrecisionType>: OpParam{
init(opDesc: OpDesc, scope: Scope) throws {
}
typealias ParamPrecisionType = P
}
class FetchOp<P: PrecisionType>: Operator<FetchParam<P>>, Runable, Creator {
typealias OpType = FetchOp<P>
func runImpl() {
print("feed op")
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
/*
let opInputsOutputsKey = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
gBatchNormType : (inputs: ["X"], outputs: ["Y"]),
gReluType : (inputs: ["X"], outputs: ["Out"]),
gElementwiseAdd : (inputs: ["X", "Y"], outputs: ["Out"])]
*/
protocol OpParam {
associatedtype ParamPrecisionType: PrecisionType
init(opDesc: OpDesc, scope: Scope) throws
static func getFirstTensor<VarType: Variant>(key: String, map: [String : [String]], from: Scope) throws -> VarType
static func inputX<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
static func inputBiase<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
static func inputMean<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
static func inputScale<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
static func inputVariance<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
static func inputFilter<VarType: Variant>(paraInputs: [String : [String]], from: Scope) throws -> VarType
static func input<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
static func output<VarType: Variant>(outputs: [String : [String]], from: Scope) throws -> VarType
static func outputY<VarType: Variant>(outputs: [String : [String]], from: Scope) throws -> VarType
static func inputY<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
static func outputOut<VarType: Variant>(outputs: [String : [String]], from: Scope) throws -> VarType
static func getAttr<T>(key: String, attrs: [String : Attr]) throws -> T
}
extension OpParam {
static func getFirstTensor<VarType: Variant>(key: String, map: [String : [String]], from: Scope) throws -> VarType {
guard let mapKeys = map[key], mapKeys.count > 0 else {
throw PaddleMobileError.paramError(message: key + " not found in \(map) or maped values is empty")
}
guard let variant = from[mapKeys[0]], let v = variant as? VarType else {
throw PaddleMobileError.paramError(message: mapKeys[0] + " not found in scope")
}
return v
}
static func inputX<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorX: VarType = try getFirstTensor(key: "X", map: inputs, from: from)
return tensorX
} catch let error {
throw error
}
}
static func input<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorInput: VarType = try getFirstTensor(key: "Input", map: inputs, from: from)
return tensorInput
} catch let error {
throw error
}
}
static func output<VarType: Variant>(outputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorOutput: VarType = try getFirstTensor(key: "Output", map: outputs, from: from)
return tensorOutput
} catch let error {
throw error
}
}
static func outputY<VarType: Variant>(outputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorOutputY: VarType = try getFirstTensor(key: "Y", map: outputs, from: from)
return tensorOutputY
} catch let error {
throw error
}
}
static func inputY<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorY: VarType = try getFirstTensor(key: "Y", map: inputs, from: from)
return tensorY
} catch let error {
throw error
}
}
static func outputOut<VarType: Variant>(outputs: [String : [String]], from: Scope) throws -> VarType {
do {
let out: VarType = try getFirstTensor(key: "Out", map: outputs, from: from)
return out
} catch let error {
throw error
}
}
static func inputFilter<VarType: Variant>(paraInputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorFilter: VarType = try getFirstTensor(key: "Filter", map: paraInputs, from: from)
return tensorFilter
} catch let error {
throw error
}
}
static func inputBiase<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorBias: VarType = try getFirstTensor(key: "Bias", map: inputs, from: from)
return tensorBias
} catch let error {
throw error
}
}
static func inputMean<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorMean: VarType = try getFirstTensor(key: "Mean", map: inputs, from: from)
return tensorMean
} catch let error {
throw error
}
}
static func inputScale<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorScale: VarType = try getFirstTensor(key: "Scale", map: inputs, from: from)
return tensorScale
} catch let error {
throw error
}
}
static func inputVariance<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorVariance: VarType = try getFirstTensor(key: "Variance", map: inputs, from: from)
return tensorVariance
} catch let error {
throw error
}
}
static func getAttr<T>(key: String, attrs: [String : Attr]) throws -> T{
guard let attr = attrs[key] else {
throw PaddleMobileError.paramError(message: "attr \(key) can't found in: \(attrs)" )
}
guard let tAttr = attr as? T else {
throw PaddleMobileError.paramError(message: "key: \(key) attr: \(attr) type error" )
}
return tAttr
}
}
......@@ -14,152 +14,63 @@
import Foundation
/*
let opInputsOutputsKey = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
gBatchNormType : (inputs: ["X"], outputs: ["Y"]),
gReluType : (inputs: ["X"], outputs: ["Out"]),
gElementwiseAdd : (inputs: ["X", "Y"], outputs: ["Out"])]
*/
protocol Param {
associatedtype ParamP: PrecisionType
init(opDesc: OpDesc, scope: Scope) throws
static func getFirstTensor(key: String, map: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func inputX(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func inputBiase(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func inputMean(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func inputScale(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func inputVariance(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func inputFilter(paraInputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func input(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func output(outputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func outputY(outputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func inputY(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func outputOut(outputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func getAttr<T>(key: String, attrs: [String : Attr]) throws -> T
protocol Runable {
func run()
func runImpl()
}
extension Param {
static func getFirstTensor(key: String, map: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
guard let mapKeys = map["X"], mapKeys.count > 0, let inputX = from[mapKeys[0]], let tensorX = inputX as? Tensor<ParamP> else {
throw PaddleMobileError.paramError(message: "tensor " + key + "in \(map) not found")
}
return tensorX
}
static func inputX(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let tensorX = try getFirstTensor(key: "X", map: inputs, from: from)
return tensorX
} catch let error {
throw error
}
}
static func input(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let tensorInput = try getFirstTensor(key: "Input", map: inputs, from: from)
return tensorInput
} catch let error {
throw error
}
}
static func output(outputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let tensorOutput = try getFirstTensor(key: "Output", map: outputs, from: from)
return tensorOutput
} catch let error {
throw error
}
}
static func outputY(outputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let tensorOutputY = try getFirstTensor(key: "Y", map: outputs, from: from)
return tensorOutputY
} catch let error {
throw error
}
}
static func inputY(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let tensorY = try getFirstTensor(key: "Y", map: inputs, from: from)
return tensorY
} catch let error {
throw error
}
}
static func outputOut(outputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let out = try getFirstTensor(key: "Out", map: outputs, from: from)
return out
} catch let error {
throw error
}
}
static func inputFilter(paraInputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let tensorFilter = try getFirstTensor(key: "Filter", map: paraInputs, from: from)
return tensorFilter
} catch let error {
throw error
}
extension Runable where Self: OperatorProtocol{
func run() {
runImpl()
}
}
static func inputBiase(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let tensorBias = try getFirstTensor(key: "Bias", map: inputs, from: from)
return tensorBias
} catch let error {
throw error
}
}
protocol Creator where Self: OperatorProtocol{
associatedtype OpType: OperatorProtocol
static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType
}
static func inputMean(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
extension Creator where Self: OperatorProtocol {
static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType {
do {
let tensorMean = try getFirstTensor(key: "Mean", map: inputs, from: from)
return tensorMean
return try OpType.provide(opDesc: opDesc, inScope: inScope)
} catch let error {
throw error
}
}
}
static func inputScale(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let tensorScale = try getFirstTensor(key: "Scale", map: inputs, from: from)
return tensorScale
} catch let error {
throw error
}
}
protocol OperatorProtocol {
associatedtype ParamType: OpParam
var type: String { get }
var inputs: [String : [String]] { get }
var paraInputs: [String : [String]] { get }
var outpus: [String : [String]] { get }
var attrs: [String : Attr] { get }
var para: ParamType { get }
init(opDesc: OpDesc, inScope: Scope) throws
}
static func inputVariance(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
extension OperatorProtocol {
static func provide(opDesc: OpDesc, inScope: Scope) throws -> Self {
do {
let tensorVariance = try getFirstTensor(key: "Variance", map: inputs, from: from)
return tensorVariance
return try Self.init(opDesc: opDesc, inScope: inScope)
} catch let error {
throw error
}
}
static func getAttr<T>(key: String, attrs: [String : Attr]) throws -> T{
guard let attr = attrs[key] as? T else {
throw PaddleMobileError.paramError(message: "attr type error")
}
return attr
}
}
class Operator<ParamType: Param> {
class Operator <ParameterType: OpParam>: OperatorProtocol{
typealias ParamType = ParameterType
let type: String
let inputs: [String : [String]]
let paraInputs: [String : [String]]
let outpus: [String : [String]]
let attrs: [String : Attr]
let para: ParamType
init(opDesc: OpDesc, inScope: Scope) throws {
required init(opDesc: OpDesc, inScope: Scope) throws {
type = opDesc.type
inputs = opDesc.inputs
outpus = opDesc.outputs
......@@ -171,12 +82,4 @@ class Operator<ParamType: Param> {
throw error
}
}
func run() {
runImpl()
}
func runImpl() {
fatalError("runimpl of " + type + "op not implement")
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License. */
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
struct ReluParam<P: PrecisionType>: Param {
typealias ParamP = P
struct ReluParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws {
do {
inputX = try ReluParam.inputX(inputs: opDesc.inputs, from: scope)
......@@ -23,13 +24,14 @@ struct ReluParam<P: PrecisionType>: Param {
throw error
}
}
let inputX: Tensor<ParamP>
let out: Tensor<ParamP>
let inputX: Texture
let out: Texture
}
class ReluOp<P: PrecisionType>: Operator<ReluParam<P>> {
override func runImpl() {
class ReluOp<P: PrecisionType>: Operator<ReluParam<P>>, Runable, Creator{
typealias OpType = ReluOp<P>
func runImpl() {
print("this is ReluOp")
}
}
......
......@@ -34,19 +34,19 @@ struct OpDesc {
}
inputs = creator(protoOpDesc.inputs) {
opInputsOutputsKey[protoOpDesc.type]?.inputs.contains($0) ?? false
opInfos[protoOpDesc.type]?.inputs.contains($0) ?? false
}
paraInputs = creator(protoOpDesc.inputs) {
!(opInputsOutputsKey[protoOpDesc.type]?.inputs.contains($0) ?? false)
!(opInfos[protoOpDesc.type]?.inputs.contains($0) ?? false)
}
outputs = creator(protoOpDesc.outputs) {
opInputsOutputsKey[protoOpDesc.type]?.outputs.contains($0) ?? false
opInfos[protoOpDesc.type]?.outputs.contains($0) ?? false
}
unusedOutputs = creator(protoOpDesc.outputs) {
!(opInputsOutputsKey[protoOpDesc.type]?.outputs.contains($0) ?? false)
!(opInfos[protoOpDesc.type]?.outputs.contains($0) ?? false)
}
for attr in protoOpDesc.attrs {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class Texture {
}
//
// paddle_mobile.h
// paddle-mobile
//
// Created by liuRuiLong on 2018/6/25.
// Copyright © 2018年 orange. All rights reserved.
//
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#import <UIKit/UIKit.h>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册