未验证 提交 5dfaa551 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #455 from codeWorm2015/metal

fix #454 metal can run ops
...@@ -70,15 +70,13 @@ build ...@@ -70,15 +70,13 @@ build
cmake-build-debug cmake-build-debug
cmake-build-release cmake-build-release
# Pods
# metal
Podfile.lock Podfile.lock
SwiftProtobuf.framework
metal/Pods/ metal/Pods/
SwiftProtobuf.framework
paddle-mobile.xcworkspace paddle-mobile.xcworkspace
metal/models/
......
...@@ -13,9 +13,26 @@ ...@@ -13,9 +13,26 @@
FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8520E11C550081E9F8 /* Main.storyboard */; }; FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8520E11C550081E9F8 /* Main.storyboard */; };
FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8820E11C560081E9F8 /* Assets.xcassets */; }; FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8820E11C560081E9F8 /* Assets.xcassets */; };
FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8A20E11C560081E9F8 /* LaunchScreen.storyboard */; }; FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8A20E11C560081E9F8 /* LaunchScreen.storyboard */; };
FC039BC220E11CD00081E9F8 /* test.pb.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BC120E11CD00081E9F8 /* test.pb.swift */; }; FC9D037D20E22E4E000F735A /* params in Resources */ = {isa = PBXBuildFile; fileRef = FC9D037B20E22E4E000F735A /* params */; };
FC9D037E20E22E4E000F735A /* model in Resources */ = {isa = PBXBuildFile; fileRef = FC9D037C20E22E4E000F735A /* model */; };
FCEBEC2C20E1391F00C0B14D /* paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; };
FCEBEC2D20E1391F00C0B14D /* paddle_mobile.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
/* End PBXBuildFile section */ /* End PBXBuildFile section */
/* Begin PBXCopyFilesBuildPhase section */
FCEBEC2E20E1392000C0B14D /* Embed Frameworks */ = {
isa = PBXCopyFilesBuildPhase;
buildActionMask = 2147483647;
dstPath = "";
dstSubfolderSpec = 10;
files = (
FCEBEC2D20E1391F00C0B14D /* paddle_mobile.framework in Embed Frameworks */,
);
name = "Embed Frameworks";
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXCopyFilesBuildPhase section */
/* Begin PBXFileReference section */ /* Begin PBXFileReference section */
081C9CF10DB06C58B8B6B039 /* Pods-paddle-mobile-demo.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-demo.release.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-demo/Pods-paddle-mobile-demo.release.xcconfig"; sourceTree = "<group>"; }; 081C9CF10DB06C58B8B6B039 /* Pods-paddle-mobile-demo.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-demo.release.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-demo/Pods-paddle-mobile-demo.release.xcconfig"; sourceTree = "<group>"; };
18896810981724F8A0FED62A /* Pods_paddle_mobile_demo.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile_demo.framework; sourceTree = BUILT_PRODUCTS_DIR; }; 18896810981724F8A0FED62A /* Pods_paddle_mobile_demo.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile_demo.framework; sourceTree = BUILT_PRODUCTS_DIR; };
...@@ -27,7 +44,9 @@ ...@@ -27,7 +44,9 @@
FC039B8820E11C560081E9F8 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; }; FC039B8820E11C560081E9F8 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
FC039B8B20E11C560081E9F8 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/LaunchScreen.storyboard; sourceTree = "<group>"; }; FC039B8B20E11C560081E9F8 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/LaunchScreen.storyboard; sourceTree = "<group>"; };
FC039B8D20E11C560081E9F8 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; }; FC039B8D20E11C560081E9F8 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
FC039BC120E11CD00081E9F8 /* test.pb.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = test.pb.swift; sourceTree = "<group>"; }; FC9D037B20E22E4E000F735A /* params */ = {isa = PBXFileReference; lastKnownFileType = file; path = params; sourceTree = "<group>"; };
FC9D037C20E22E4E000F735A /* model */ = {isa = PBXFileReference; lastKnownFileType = file; path = model; sourceTree = "<group>"; };
FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */ /* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */ /* Begin PBXFrameworksBuildPhase section */
...@@ -35,6 +54,7 @@ ...@@ -35,6 +54,7 @@
isa = PBXFrameworksBuildPhase; isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647; buildActionMask = 2147483647;
files = ( files = (
FCEBEC2C20E1391F00C0B14D /* paddle_mobile.framework in Frameworks */,
30D0ED21F392CFA3885B1002 /* Pods_paddle_mobile_demo.framework in Frameworks */, 30D0ED21F392CFA3885B1002 /* Pods_paddle_mobile_demo.framework in Frameworks */,
); );
runOnlyForDeploymentPostprocessing = 0; runOnlyForDeploymentPostprocessing = 0;
...@@ -62,6 +82,7 @@ ...@@ -62,6 +82,7 @@
FC039B7520E11C550081E9F8 = { FC039B7520E11C550081E9F8 = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */,
FC039B8020E11C550081E9F8 /* paddle-mobile-demo */, FC039B8020E11C550081E9F8 /* paddle-mobile-demo */,
FC039B7F20E11C550081E9F8 /* Products */, FC039B7F20E11C550081E9F8 /* Products */,
5722B50FEC38F55CA9B6A57B /* Pods */, 5722B50FEC38F55CA9B6A57B /* Pods */,
...@@ -80,7 +101,7 @@ ...@@ -80,7 +101,7 @@
FC039B8020E11C550081E9F8 /* paddle-mobile-demo */ = { FC039B8020E11C550081E9F8 /* paddle-mobile-demo */ = {
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FC039BC120E11CD00081E9F8 /* test.pb.swift */, FC9D037A20E22E4E000F735A /* yolo */,
FC039B8120E11C550081E9F8 /* AppDelegate.swift */, FC039B8120E11C550081E9F8 /* AppDelegate.swift */,
FC039B8320E11C550081E9F8 /* ViewController.swift */, FC039B8320E11C550081E9F8 /* ViewController.swift */,
FC039B8520E11C550081E9F8 /* Main.storyboard */, FC039B8520E11C550081E9F8 /* Main.storyboard */,
...@@ -91,6 +112,16 @@ ...@@ -91,6 +112,16 @@
path = "paddle-mobile-demo"; path = "paddle-mobile-demo";
sourceTree = "<group>"; sourceTree = "<group>";
}; };
FC9D037A20E22E4E000F735A /* yolo */ = {
isa = PBXGroup;
children = (
FC9D037B20E22E4E000F735A /* params */,
FC9D037C20E22E4E000F735A /* model */,
);
name = yolo;
path = ../../models/yolo;
sourceTree = "<group>";
};
/* End PBXGroup section */ /* End PBXGroup section */
/* Begin PBXNativeTarget section */ /* Begin PBXNativeTarget section */
...@@ -103,6 +134,7 @@ ...@@ -103,6 +134,7 @@
FC039B7B20E11C550081E9F8 /* Frameworks */, FC039B7B20E11C550081E9F8 /* Frameworks */,
FC039B7C20E11C550081E9F8 /* Resources */, FC039B7C20E11C550081E9F8 /* Resources */,
84ED590C0E51ABA9C34F51B5 /* [CP] Embed Pods Frameworks */, 84ED590C0E51ABA9C34F51B5 /* [CP] Embed Pods Frameworks */,
FCEBEC2E20E1392000C0B14D /* Embed Frameworks */,
); );
buildRules = ( buildRules = (
); );
...@@ -152,7 +184,9 @@ ...@@ -152,7 +184,9 @@
buildActionMask = 2147483647; buildActionMask = 2147483647;
files = ( files = (
FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */, FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */,
FC9D037E20E22E4E000F735A /* model in Resources */,
FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */, FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */,
FC9D037D20E22E4E000F735A /* params in Resources */,
FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */, FC039B8720E11C550081E9F8 /* Main.storyboard in Resources */,
); );
runOnlyForDeploymentPostprocessing = 0; runOnlyForDeploymentPostprocessing = 0;
...@@ -204,7 +238,6 @@ ...@@ -204,7 +238,6 @@
buildActionMask = 2147483647; buildActionMask = 2147483647;
files = ( files = (
FC039B8420E11C550081E9F8 /* ViewController.swift in Sources */, FC039B8420E11C550081E9F8 /* ViewController.swift in Sources */,
FC039BC220E11CD00081E9F8 /* test.pb.swift in Sources */,
FC039B8220E11C550081E9F8 /* AppDelegate.swift in Sources */, FC039B8220E11C550081E9F8 /* AppDelegate.swift in Sources */,
); );
runOnlyForDeploymentPostprocessing = 0; runOnlyForDeploymentPostprocessing = 0;
......
// /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// AppDelegate.swift
// paddle-mobile-demo Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// Created by liuRuiLong on 2018/6/25. You may obtain a copy of the License at
// Copyright © 2018年 orange. All rights reserved.
// http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import UIKit import UIKit
......
// /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// ViewController.swift
// paddle-mobile-demo Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// Created by liuRuiLong on 2018/6/25. You may obtain a copy of the License at
// Copyright © 2018年 orange. All rights reserved.
// http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import UIKit import UIKit
import paddle_mobile import paddle_mobile
...@@ -14,14 +19,17 @@ class ViewController: UIViewController { ...@@ -14,14 +19,17 @@ class ViewController: UIViewController {
override func viewDidLoad() { override func viewDidLoad() {
super.viewDidLoad() super.viewDidLoad()
let loader = Loader<Float>.init()
do {
let modelPath = Bundle.main.path(forResource: "model", ofType: nil) ?! "model null"
let paraPath = Bundle.main.path(forResource: "params", ofType: nil) ?! "para null"
let program = try loader.load(modelPath: modelPath, paraPath: paraPath)
let executor = try Executor<Float>.init(program: program)
executor.predict()
} catch let error {
print(error)
}
} }
override func didReceiveMemoryWarning() {
super.didReceiveMemoryWarning()
// Dispose of any resources that can be recreated.
}
} }
// DO NOT EDIT.
//
// Generated by the Swift generator plugin for the protocol buffer compiler.
// Source: test.proto
//
// For information on using the generated types, please see the documenation:
// https://github.com/apple/swift-protobuf/
import Foundation
import SwiftProtobuf
// If the compiler emits an error on this type, it is because this file
// was generated by a version of the `protoc` Swift plug-in that is
// incompatible with the version of SwiftProtobuf to which you are linking.
// Please ensure that your are building against the same version of the API
// that was used to generate this file.
fileprivate struct _GeneratedWithProtocGenSwiftVersion: SwiftProtobuf.ProtobufAPIVersionCheck {
struct _2: SwiftProtobuf.ProtobufAPIVersion_2 {}
typealias Version = _2
}
struct BookInfo {
// SwiftProtobuf.Message conformance is added in an extension below. See the
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
// methods supported on all messages.
var id: Int64 = 0
var title: String = String()
var author: String = String()
var unknownFields = SwiftProtobuf.UnknownStorage()
init() {}
}
// MARK: - Code below here is support for the SwiftProtobuf runtime.
extension BookInfo: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
static let protoMessageName: String = "BookInfo"
static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .same(proto: "id"),
2: .same(proto: "title"),
3: .same(proto: "author"),
]
mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
while let fieldNumber = try decoder.nextFieldNumber() {
switch fieldNumber {
case 1: try decoder.decodeSingularInt64Field(value: &self.id)
case 2: try decoder.decodeSingularStringField(value: &self.title)
case 3: try decoder.decodeSingularStringField(value: &self.author)
default: break
}
}
}
func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
if self.id != 0 {
try visitor.visitSingularInt64Field(value: self.id, fieldNumber: 1)
}
if !self.title.isEmpty {
try visitor.visitSingularStringField(value: self.title, fieldNumber: 2)
}
if !self.author.isEmpty {
try visitor.visitSingularStringField(value: self.author, fieldNumber: 3)
}
try unknownFields.traverse(visitor: &visitor)
}
func _protobuf_generated_isEqualTo(other: BookInfo) -> Bool {
if self.id != other.id {return false}
if self.title != other.title {return false}
if self.author != other.author {return false}
if unknownFields != other.unknownFields {return false}
return true
}
}
...@@ -30,6 +30,10 @@ ...@@ -30,6 +30,10 @@
FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB520E11CC20081E9F8 /* OpDesc.swift */; }; FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB520E11CC20081E9F8 /* OpDesc.swift */; };
FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB620E11CC20081E9F8 /* Attribute.swift */; }; FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB620E11CC20081E9F8 /* Attribute.swift */; };
FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB720E11CC20081E9F8 /* BlockDesc.swift */; }; FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB720E11CC20081E9F8 /* BlockDesc.swift */; };
FC9D037920E229E4000F735A /* OpParam.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037820E229E4000F735A /* OpParam.swift */; };
FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037F20E22FBB000F735A /* FeedOp.swift */; };
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038120E2312E000F735A /* FetchOp.swift */; };
FC9D038420E23B01000F735A /* Texture.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038320E23B01000F735A /* Texture.swift */; };
/* End PBXBuildFile section */ /* End PBXBuildFile section */
/* Begin PBXFileReference section */ /* Begin PBXFileReference section */
...@@ -60,6 +64,10 @@ ...@@ -60,6 +64,10 @@
FC039BB520E11CC20081E9F8 /* OpDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OpDesc.swift; sourceTree = "<group>"; }; FC039BB520E11CC20081E9F8 /* OpDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OpDesc.swift; sourceTree = "<group>"; };
FC039BB620E11CC20081E9F8 /* Attribute.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Attribute.swift; sourceTree = "<group>"; }; FC039BB620E11CC20081E9F8 /* Attribute.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Attribute.swift; sourceTree = "<group>"; };
FC039BB720E11CC20081E9F8 /* BlockDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BlockDesc.swift; sourceTree = "<group>"; }; FC039BB720E11CC20081E9F8 /* BlockDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BlockDesc.swift; sourceTree = "<group>"; };
FC9D037820E229E4000F735A /* OpParam.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpParam.swift; sourceTree = "<group>"; };
FC9D037F20E22FBB000F735A /* FeedOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FeedOp.swift; sourceTree = "<group>"; };
FC9D038120E2312E000F735A /* FetchOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FetchOp.swift; sourceTree = "<group>"; };
FC9D038320E23B01000F735A /* Texture.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */ /* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */ /* Begin PBXFrameworksBuildPhase section */
...@@ -115,8 +123,8 @@ ...@@ -115,8 +123,8 @@
FC039BAE20E11CC20081E9F8 /* Program */, FC039BAE20E11CC20081E9F8 /* Program */,
FC039BA320E11CBC0081E9F8 /* Operators */, FC039BA320E11CBC0081E9F8 /* Operators */,
FC039BA120E11CB70081E9F8 /* Loader.swift */, FC039BA120E11CB70081E9F8 /* Loader.swift */,
FC039B9C20E11CB20081E9F8 /* framework */,
FC039B9A20E11CA00081E9F8 /* Executor.swift */, FC039B9A20E11CA00081E9F8 /* Executor.swift */,
FC039B9C20E11CB20081E9F8 /* framework */,
FC039B9320E11C9A0081E9F8 /* Common */, FC039B9320E11C9A0081E9F8 /* Common */,
FC039B6D20E11C3C0081E9F8 /* paddle_mobile.h */, FC039B6D20E11C3C0081E9F8 /* paddle_mobile.h */,
FC039B6E20E11C3C0081E9F8 /* Info.plist */, FC039B6E20E11C3C0081E9F8 /* Info.plist */,
...@@ -139,6 +147,7 @@ ...@@ -139,6 +147,7 @@
children = ( children = (
FC039B9D20E11CB20081E9F8 /* Tensor.swift */, FC039B9D20E11CB20081E9F8 /* Tensor.swift */,
FC039B9E20E11CB20081E9F8 /* Dim.swift */, FC039B9E20E11CB20081E9F8 /* Dim.swift */,
FC9D038320E23B01000F735A /* Texture.swift */,
); );
path = framework; path = framework;
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -151,6 +160,9 @@ ...@@ -151,6 +160,9 @@
FC039BA620E11CBC0081E9F8 /* Operator.swift */, FC039BA620E11CBC0081E9F8 /* Operator.swift */,
FC039BA720E11CBC0081E9F8 /* BatchNormOp.swift */, FC039BA720E11CBC0081E9F8 /* BatchNormOp.swift */,
FC039BA820E11CBC0081E9F8 /* ReluOp.swift */, FC039BA820E11CBC0081E9F8 /* ReluOp.swift */,
FC9D037820E229E4000F735A /* OpParam.swift */,
FC9D037F20E22FBB000F735A /* FeedOp.swift */,
FC9D038120E2312E000F735A /* FetchOp.swift */,
); );
path = Operators; path = Operators;
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -271,21 +283,25 @@ ...@@ -271,21 +283,25 @@
isa = PBXSourcesBuildPhase; isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647; buildActionMask = 2147483647;
files = ( files = (
FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */,
FC039B9F20E11CB20081E9F8 /* Tensor.swift in Sources */, FC039B9F20E11CB20081E9F8 /* Tensor.swift in Sources */,
FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */, FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */,
FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */, FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */,
FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */, FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */,
FC9D037920E229E4000F735A /* OpParam.swift in Sources */,
FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */, FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */,
FC039BA020E11CB20081E9F8 /* Dim.swift in Sources */, FC039BA020E11CB20081E9F8 /* Dim.swift in Sources */,
FC039BB820E11CC20081E9F8 /* framework.pb.swift in Sources */, FC039BB820E11CC20081E9F8 /* framework.pb.swift in Sources */,
FC039B9920E11C9A0081E9F8 /* Types.swift in Sources */, FC039B9920E11C9A0081E9F8 /* Types.swift in Sources */,
FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */, FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */,
FC9D038420E23B01000F735A /* Texture.swift in Sources */,
FC039B9820E11C9A0081E9F8 /* Errors.swift in Sources */, FC039B9820E11C9A0081E9F8 /* Errors.swift in Sources */,
FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */, FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */,
FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */, FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */,
FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */, FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */,
FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */, FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */,
FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */, FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */,
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */,
FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */, FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */,
FC039BA220E11CB70081E9F8 /* Loader.swift in Sources */, FC039BA220E11CB70081E9F8 /* Loader.swift in Sources */,
FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */, FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */,
......
...@@ -19,4 +19,5 @@ public enum PaddleMobileError: Error{ ...@@ -19,4 +19,5 @@ public enum PaddleMobileError: Error{
case netError(message: String) case netError(message: String)
case memoryError(message: String) case memoryError(message: String)
case paramError(message: String) case paramError(message: String)
case opError(message: String)
} }
...@@ -20,7 +20,7 @@ precedencegroup ExecutedOrFatalError{ ...@@ -20,7 +20,7 @@ precedencegroup ExecutedOrFatalError{
higherThan: AssignmentPrecedence higherThan: AssignmentPrecedence
} }
infix operator ?!: ExecutedOrFatalError infix operator ?!: ExecutedOrFatalError
func ?!<T>(option: T?, excuteOrError: @autoclosure () -> String) -> T{ public func ?!<T>(option: T?, excuteOrError: @autoclosure () -> String) -> T{
if let inOpt = option { if let inOpt = option {
return inOpt return inOpt
}else{ }else{
......
...@@ -35,16 +35,58 @@ protocol Variant { ...@@ -35,16 +35,58 @@ protocol Variant {
extension Tensor: Variant { extension Tensor: Variant {
} }
extension Texture: Variant {
}
let gFetchType = "fetch"
let gFeedType = "feed"
let gConvType = "conv2d" let gConvType = "conv2d"
let gBatchNormType = "batch_norm" let gBatchNormType = "batch_norm"
let gReluType = "relu" let gReluType = "relu"
let gElementwiseAdd = "elementwise_add" let gElementwiseAdd = "elementwise_add"
let opInputsOutputsKey = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
gBatchNormType : (inputs: ["X"], outputs: ["Y"]), fileprivate var singletons : [String : Any] = [:]
gReluType : (inputs: ["X"], outputs: ["Out"]), class OpCreator<P: PrecisionType> {
gElementwiseAdd : (inputs: ["X", "Y"], outputs: ["Out"])] static var shared : OpCreator<P> {
let key = String(describing: P.self)
if let singleton = singletons[key] {
return singleton as! OpCreator<P>
} else {
let newSingleton = OpCreator<P>()
singletons[key] = newSingleton
return newSingleton
}
}
func creat(opDesc: OpDesc, scope: Scope) throws -> Runable {
guard let opCreator = opCreators[opDesc.type] else {
throw PaddleMobileError.opError(message: "there is no " + opDesc.type + " yet")
}
do {
return try opCreator(opDesc, scope)
} catch let error {
throw error
}
}
let opCreators: [String : (OpDesc, Scope) throws -> Runable] =
[gConvType : ConvOp<P>.creat,
gBatchNormType : BatchNormOp<P>.creat,
gReluType : ReluOp<P>.creat,
gElementwiseAdd : ElementwiseAddOp<P>.creat,
gFeedType : FeedOp<P>.creat,
gFetchType : FetchOp<P>.creat]
private init(){}
}
let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
gBatchNormType : (inputs: ["X"], outputs: ["Y"]),
gReluType : (inputs: ["X"], outputs: ["Out"]),
gElementwiseAdd : (inputs: ["X", "Y"], outputs: ["Out"]),
gFeedType : (inputs: ["X"], outputs: ["Out"]),
gFetchType : (inputs: ["X"], outputs: ["Out"])]
...@@ -14,6 +14,31 @@ ...@@ -14,6 +14,31 @@
import Foundation import Foundation
class Executor { public class Executor<P: PrecisionType> {
var ops: [Runable] = []
public init(program: Program) throws {
for block in program.programDesc.blocks {
for varDesc in block.vars {
if !varDesc.persistable {
program.scope.vars[varDesc.name] = Texture.init()
}
}
for op in block.ops {
do {
let op = try OpCreator<P>.shared.creat(opDesc: op, scope: program.scope)
ops.append(op)
} catch let error {
throw error
}
}
}
}
public func predict() {
for op in ops {
op.run()
}
}
} }
//public let paddle_executor: Executor = Executor.init()
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. ///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
You may obtain a copy of the License at // You may obtain a copy of the License at
//
http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
//
Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
limitations under the License. */ // limitations under the License. */
import Foundation import Foundation
struct BatchNormParam<P: PrecisionType>: Param { struct BatchNormParam<P: PrecisionType>: OpParam {
typealias ParamP = P typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws { init(opDesc: OpDesc, scope: Scope) throws {
do { do {
inputX = try BatchNormParam.inputX(inputs: opDesc.inputs, from: scope) inputX = try BatchNormParam.inputX(inputs: opDesc.inputs, from: scope)
...@@ -31,17 +31,25 @@ struct BatchNormParam<P: PrecisionType>: Param { ...@@ -31,17 +31,25 @@ struct BatchNormParam<P: PrecisionType>: Param {
throw error throw error
} }
} }
let inputX: Tensor<ParamP> let inputX: Texture
let outputY: Tensor<ParamP> let outputY: Texture
let inputBias: Tensor<ParamP> let inputBias: Tensor<ParamPrecisionType>
let inputMean: Tensor<ParamP> let inputMean: Tensor<ParamPrecisionType>
let inputScale: Tensor<ParamP> let inputScale: Tensor<ParamPrecisionType>
let inputVariance: Tensor<ParamP> let inputVariance: Tensor<ParamPrecisionType>
let epsilon: Float let epsilon: Float
let momentum: Float let momentum: Float
let is_test: Bool let is_test: Bool
} }
class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>> { class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>>, Runable, Creator{
typealias OpType = BatchNormOp<P>
func runImpl() {
print("this is BatchNormOp")
}
} }
...@@ -14,14 +14,14 @@ ...@@ -14,14 +14,14 @@
import Foundation import Foundation
struct ConvParam<P: PrecisionType>: Param { struct ConvParam<P: PrecisionType>: OpParam {
typealias ParamP = P typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws { init(opDesc: OpDesc, scope: Scope) throws {
do { do {
filter = try ConvParam.inputFilter(paraInputs: opDesc.paraInputs, from: scope) filter = try ConvParam.inputFilter(paraInputs: opDesc.paraInputs, from: scope)
input = try ConvParam.input(inputs: opDesc.inputs, from: scope) input = try ConvParam.input(inputs: opDesc.inputs, from: scope)
output = try ConvParam.output(outputs: opDesc.outputs, from: scope) output = try ConvParam.output(outputs: opDesc.outputs, from: scope)
stride = try ConvParam.getAttr(key: "stride", attrs: opDesc.attrs) stride = try ConvParam.getAttr(key: "strides", attrs: opDesc.attrs)
paddings = try ConvParam.getAttr(key: "paddings", attrs: opDesc.attrs) paddings = try ConvParam.getAttr(key: "paddings", attrs: opDesc.attrs)
dilations = try ConvParam.getAttr(key: "dilations", attrs: opDesc.attrs) dilations = try ConvParam.getAttr(key: "dilations", attrs: opDesc.attrs)
groups = try ConvParam.getAttr(key: "groups", attrs: opDesc.attrs) groups = try ConvParam.getAttr(key: "groups", attrs: opDesc.attrs)
...@@ -30,17 +30,18 @@ struct ConvParam<P: PrecisionType>: Param { ...@@ -30,17 +30,18 @@ struct ConvParam<P: PrecisionType>: Param {
} }
} }
let input: Tensor<ParamP> let input: Texture
let output: Tensor<ParamP> let output: Texture
let filter: Tensor<ParamP> let filter: Tensor<ParamPrecisionType>
let stride: [Int] let stride: [Int32]
let paddings: [Int] let paddings: [Int32]
let dilations: [Int] let dilations: [Int32]
let groups: Int let groups: Int
} }
class ConvOp<P: PrecisionType>: Operator<ConvParam<P>> { class ConvOp<P: PrecisionType>: Operator<ConvParam<P>>, Runable, Creator {
override func runImpl() { typealias OpType = ConvOp<P>
func runImpl() {
print("this is conv")
} }
} }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. ///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
You may obtain a copy of the License at // You may obtain a copy of the License at
//
http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
//
Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
limitations under the License. */ // limitations under the License. */
import Foundation import Foundation
struct ElementwiseAddParam<P: PrecisionType>: Param { struct ElementwiseAddParam<P: PrecisionType>: OpParam {
typealias ParamP = P typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws { init(opDesc: OpDesc, scope: Scope) throws {
do { do {
inputX = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: scope) inputX = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: scope)
...@@ -25,14 +26,21 @@ struct ElementwiseAddParam<P: PrecisionType>: Param { ...@@ -25,14 +26,21 @@ struct ElementwiseAddParam<P: PrecisionType>: Param {
throw error throw error
} }
} }
let inputX: Tensor<P> let inputX: Texture
let inputY: Tensor<P> let inputY: Tensor<P>
let out: Tensor<P> let out: Texture
let axis: Int let axis: Int
} }
class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>>{ class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>>, Runable, Creator{
override func runImpl() { typealias OpType = ElementwiseAddOp<P>
func runImpl() {
print("this is ElementwiseAddOp")
} }
} }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
struct FeedParam<P: PrecisionType>: OpParam{
init(opDesc: OpDesc, scope: Scope) throws {
}
typealias ParamPrecisionType = P
}
class FeedOp<P: PrecisionType>: Operator<FeedParam<P>>, Runable, Creator {
typealias OpType = FeedOp<P>
func runImpl() {
print("feed op")
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
struct FetchParam<P: PrecisionType>: OpParam{
init(opDesc: OpDesc, scope: Scope) throws {
}
typealias ParamPrecisionType = P
}
class FetchOp<P: PrecisionType>: Operator<FetchParam<P>>, Runable, Creator {
typealias OpType = FetchOp<P>
func runImpl() {
print("feed op")
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
/*
let opInputsOutputsKey = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
gBatchNormType : (inputs: ["X"], outputs: ["Y"]),
gReluType : (inputs: ["X"], outputs: ["Out"]),
gElementwiseAdd : (inputs: ["X", "Y"], outputs: ["Out"])]
*/
protocol OpParam {
associatedtype ParamPrecisionType: PrecisionType
init(opDesc: OpDesc, scope: Scope) throws
static func getFirstTensor<VarType: Variant>(key: String, map: [String : [String]], from: Scope) throws -> VarType
static func inputX<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
static func inputBiase<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
static func inputMean<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
static func inputScale<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
static func inputVariance<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
static func inputFilter<VarType: Variant>(paraInputs: [String : [String]], from: Scope) throws -> VarType
static func input<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
static func output<VarType: Variant>(outputs: [String : [String]], from: Scope) throws -> VarType
static func outputY<VarType: Variant>(outputs: [String : [String]], from: Scope) throws -> VarType
static func inputY<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType
static func outputOut<VarType: Variant>(outputs: [String : [String]], from: Scope) throws -> VarType
static func getAttr<T>(key: String, attrs: [String : Attr]) throws -> T
}
extension OpParam {
static func getFirstTensor<VarType: Variant>(key: String, map: [String : [String]], from: Scope) throws -> VarType {
guard let mapKeys = map[key], mapKeys.count > 0 else {
throw PaddleMobileError.paramError(message: key + " not found in \(map) or maped values is empty")
}
guard let variant = from[mapKeys[0]], let v = variant as? VarType else {
throw PaddleMobileError.paramError(message: mapKeys[0] + " not found in scope")
}
return v
}
static func inputX<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorX: VarType = try getFirstTensor(key: "X", map: inputs, from: from)
return tensorX
} catch let error {
throw error
}
}
static func input<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorInput: VarType = try getFirstTensor(key: "Input", map: inputs, from: from)
return tensorInput
} catch let error {
throw error
}
}
static func output<VarType: Variant>(outputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorOutput: VarType = try getFirstTensor(key: "Output", map: outputs, from: from)
return tensorOutput
} catch let error {
throw error
}
}
static func outputY<VarType: Variant>(outputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorOutputY: VarType = try getFirstTensor(key: "Y", map: outputs, from: from)
return tensorOutputY
} catch let error {
throw error
}
}
static func inputY<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorY: VarType = try getFirstTensor(key: "Y", map: inputs, from: from)
return tensorY
} catch let error {
throw error
}
}
static func outputOut<VarType: Variant>(outputs: [String : [String]], from: Scope) throws -> VarType {
do {
let out: VarType = try getFirstTensor(key: "Out", map: outputs, from: from)
return out
} catch let error {
throw error
}
}
static func inputFilter<VarType: Variant>(paraInputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorFilter: VarType = try getFirstTensor(key: "Filter", map: paraInputs, from: from)
return tensorFilter
} catch let error {
throw error
}
}
static func inputBiase<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorBias: VarType = try getFirstTensor(key: "Bias", map: inputs, from: from)
return tensorBias
} catch let error {
throw error
}
}
static func inputMean<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorMean: VarType = try getFirstTensor(key: "Mean", map: inputs, from: from)
return tensorMean
} catch let error {
throw error
}
}
static func inputScale<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorScale: VarType = try getFirstTensor(key: "Scale", map: inputs, from: from)
return tensorScale
} catch let error {
throw error
}
}
static func inputVariance<VarType: Variant>(inputs: [String : [String]], from: Scope) throws -> VarType {
do {
let tensorVariance: VarType = try getFirstTensor(key: "Variance", map: inputs, from: from)
return tensorVariance
} catch let error {
throw error
}
}
static func getAttr<T>(key: String, attrs: [String : Attr]) throws -> T{
guard let attr = attrs[key] else {
throw PaddleMobileError.paramError(message: "attr \(key) can't found in: \(attrs)" )
}
guard let tAttr = attr as? T else {
throw PaddleMobileError.paramError(message: "key: \(key) attr: \(attr) type error" )
}
return tAttr
}
}
...@@ -14,152 +14,63 @@ ...@@ -14,152 +14,63 @@
import Foundation import Foundation
/*
let opInputsOutputsKey = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
gBatchNormType : (inputs: ["X"], outputs: ["Y"]),
gReluType : (inputs: ["X"], outputs: ["Out"]),
gElementwiseAdd : (inputs: ["X", "Y"], outputs: ["Out"])]
*/
protocol Param { protocol Runable {
associatedtype ParamP: PrecisionType func run()
init(opDesc: OpDesc, scope: Scope) throws func runImpl()
static func getFirstTensor(key: String, map: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func inputX(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func inputBiase(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func inputMean(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func inputScale(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func inputVariance(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func inputFilter(paraInputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func input(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func output(outputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func outputY(outputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func inputY(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func outputOut(outputs: [String : [String]], from: Scope) throws -> Tensor<ParamP>
static func getAttr<T>(key: String, attrs: [String : Attr]) throws -> T
} }
extension Param { extension Runable where Self: OperatorProtocol{
static func getFirstTensor(key: String, map: [String : [String]], from: Scope) throws -> Tensor<ParamP> { func run() {
guard let mapKeys = map["X"], mapKeys.count > 0, let inputX = from[mapKeys[0]], let tensorX = inputX as? Tensor<ParamP> else { runImpl()
throw PaddleMobileError.paramError(message: "tensor " + key + "in \(map) not found")
}
return tensorX
}
static func inputX(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let tensorX = try getFirstTensor(key: "X", map: inputs, from: from)
return tensorX
} catch let error {
throw error
}
}
static func input(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let tensorInput = try getFirstTensor(key: "Input", map: inputs, from: from)
return tensorInput
} catch let error {
throw error
}
}
static func output(outputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let tensorOutput = try getFirstTensor(key: "Output", map: outputs, from: from)
return tensorOutput
} catch let error {
throw error
}
}
static func outputY(outputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let tensorOutputY = try getFirstTensor(key: "Y", map: outputs, from: from)
return tensorOutputY
} catch let error {
throw error
}
}
static func inputY(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let tensorY = try getFirstTensor(key: "Y", map: inputs, from: from)
return tensorY
} catch let error {
throw error
}
}
static func outputOut(outputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let out = try getFirstTensor(key: "Out", map: outputs, from: from)
return out
} catch let error {
throw error
}
}
static func inputFilter(paraInputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let tensorFilter = try getFirstTensor(key: "Filter", map: paraInputs, from: from)
return tensorFilter
} catch let error {
throw error
}
}
static func inputBiase(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do {
let tensorBias = try getFirstTensor(key: "Bias", map: inputs, from: from)
return tensorBias
} catch let error {
throw error
}
} }
}
static func inputMean(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> { protocol Creator where Self: OperatorProtocol{
do { associatedtype OpType: OperatorProtocol
let tensorMean = try getFirstTensor(key: "Mean", map: inputs, from: from) static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType
return tensorMean }
} catch let error {
throw error extension Creator where Self: OperatorProtocol {
} static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType {
}
static func inputScale(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
do { do {
let tensorScale = try getFirstTensor(key: "Scale", map: inputs, from: from) return try OpType.provide(opDesc: opDesc, inScope: inScope)
return tensorScale
} catch let error { } catch let error {
throw error throw error
} }
} }
}
static func inputVariance(inputs: [String : [String]], from: Scope) throws -> Tensor<ParamP> {
protocol OperatorProtocol {
associatedtype ParamType: OpParam
var type: String { get }
var inputs: [String : [String]] { get }
var paraInputs: [String : [String]] { get }
var outpus: [String : [String]] { get }
var attrs: [String : Attr] { get }
var para: ParamType { get }
init(opDesc: OpDesc, inScope: Scope) throws
}
extension OperatorProtocol {
static func provide(opDesc: OpDesc, inScope: Scope) throws -> Self {
do { do {
let tensorVariance = try getFirstTensor(key: "Variance", map: inputs, from: from) return try Self.init(opDesc: opDesc, inScope: inScope)
return tensorVariance
} catch let error { } catch let error {
throw error throw error
} }
} }
static func getAttr<T>(key: String, attrs: [String : Attr]) throws -> T{
guard let attr = attrs[key] as? T else {
throw PaddleMobileError.paramError(message: "attr type error")
}
return attr
}
} }
class Operator <ParameterType: OpParam>: OperatorProtocol{
class Operator<ParamType: Param> { typealias ParamType = ParameterType
let type: String let type: String
let inputs: [String : [String]] let inputs: [String : [String]]
let paraInputs: [String : [String]] let paraInputs: [String : [String]]
let outpus: [String : [String]] let outpus: [String : [String]]
let attrs: [String : Attr] let attrs: [String : Attr]
let para: ParamType let para: ParamType
init(opDesc: OpDesc, inScope: Scope) throws { required init(opDesc: OpDesc, inScope: Scope) throws {
type = opDesc.type type = opDesc.type
inputs = opDesc.inputs inputs = opDesc.inputs
outpus = opDesc.outputs outpus = opDesc.outputs
...@@ -171,12 +82,4 @@ class Operator<ParamType: Param> { ...@@ -171,12 +82,4 @@ class Operator<ParamType: Param> {
throw error throw error
} }
} }
func run() {
runImpl()
}
func runImpl() {
fatalError("runimpl of " + type + "op not implement")
}
} }
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. ///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
You may obtain a copy of the License at // You may obtain a copy of the License at
//
http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
//
Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
limitations under the License. */ // limitations under the License. */
import Foundation import Foundation
struct ReluParam<P: PrecisionType>: Param { struct ReluParam<P: PrecisionType>: OpParam {
typealias ParamP = P typealias ParamPrecisionType = P
init(opDesc: OpDesc, scope: Scope) throws { init(opDesc: OpDesc, scope: Scope) throws {
do { do {
inputX = try ReluParam.inputX(inputs: opDesc.inputs, from: scope) inputX = try ReluParam.inputX(inputs: opDesc.inputs, from: scope)
...@@ -23,13 +24,14 @@ struct ReluParam<P: PrecisionType>: Param { ...@@ -23,13 +24,14 @@ struct ReluParam<P: PrecisionType>: Param {
throw error throw error
} }
} }
let inputX: Tensor<ParamP> let inputX: Texture
let out: Tensor<ParamP> let out: Texture
} }
class ReluOp<P: PrecisionType>: Operator<ReluParam<P>> { class ReluOp<P: PrecisionType>: Operator<ReluParam<P>>, Runable, Creator{
override func runImpl() { typealias OpType = ReluOp<P>
func runImpl() {
print("this is ReluOp")
} }
} }
......
...@@ -34,19 +34,19 @@ struct OpDesc { ...@@ -34,19 +34,19 @@ struct OpDesc {
} }
inputs = creator(protoOpDesc.inputs) { inputs = creator(protoOpDesc.inputs) {
opInputsOutputsKey[protoOpDesc.type]?.inputs.contains($0) ?? false opInfos[protoOpDesc.type]?.inputs.contains($0) ?? false
} }
paraInputs = creator(protoOpDesc.inputs) { paraInputs = creator(protoOpDesc.inputs) {
!(opInputsOutputsKey[protoOpDesc.type]?.inputs.contains($0) ?? false) !(opInfos[protoOpDesc.type]?.inputs.contains($0) ?? false)
} }
outputs = creator(protoOpDesc.outputs) { outputs = creator(protoOpDesc.outputs) {
opInputsOutputsKey[protoOpDesc.type]?.outputs.contains($0) ?? false opInfos[protoOpDesc.type]?.outputs.contains($0) ?? false
} }
unusedOutputs = creator(protoOpDesc.outputs) { unusedOutputs = creator(protoOpDesc.outputs) {
!(opInputsOutputsKey[protoOpDesc.type]?.outputs.contains($0) ?? false) !(opInfos[protoOpDesc.type]?.outputs.contains($0) ?? false)
} }
for attr in protoOpDesc.attrs { for attr in protoOpDesc.attrs {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class Texture {
}
// /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// paddle_mobile.h
// paddle-mobile Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// Created by liuRuiLong on 2018/6/25. You may obtain a copy of the License at
// Copyright © 2018年 orange. All rights reserved.
// http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#import <UIKit/UIKit.h> #import <UIKit/UIKit.h>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册