.creat, gBatchNormType : BatchNormOp
.creat, gReluType : ReluOp
.creat,
diff --git a/metal/paddle-mobile/paddle-mobile/Executor.swift b/metal/paddle-mobile/paddle-mobile/Executor.swift
index 7274e7c0d86fd8d44bd83f515f68ca3c4932742b..92370bd724378e412dd34344c8769f21b7fc5d3e 100644
--- a/metal/paddle-mobile/paddle-mobile/Executor.swift
+++ b/metal/paddle-mobile/paddle-mobile/Executor.swift
@@ -15,17 +15,15 @@
import Foundation
public class Executor .shared.creat(opDesc: op, scope: program.scope)
+ let op = try OpCreator .shared.creat(opDesc: op, scope: inProgram.scope)
+ op.inferShape()
ops.append(op)
} catch let error {
throw error
@@ -34,10 +32,16 @@ public class Executor .init(inData: paraData)
- scope.vars[varDesc.name] = tensor
+ scope[varDesc.name] = tensor
+ } else {
+ scope[varDesc.name] = Texture.init()
}
+ } else {
+ scope[varDesc.name] = Texture.init()
}
}
}
+
+ let block = programDesc.blocks[0]
+ guard let firstOp = block.ops.first, let lastOp = block.ops.last else {
+ throw PaddleMobileError.loaderError(message: "at least two operator")
+ }
+ guard firstOp.type == gFeedType, lastOp.type == gFetchType else {
+ throw PaddleMobileError.loaderError(message: "the first op is not feed or the last op is not fetch")
+ }
+
+ guard let inputKey = opInfos[gFeedType]?.inputs.first, let outKey = opInfos[gFetchType]?.outputs.first else {
+ throw PaddleMobileError.loaderError(message: "the feed input key or fetch output key not found")
+ }
+ guard let feedKey = firstOp.inputs[inputKey]?.first, let fetchKey = lastOp.outputs[outKey]?.first else {
+ throw PaddleMobileError.loaderError(message: "feed key or fetch key not found")
+ }
+
+ let program = Program.init(protoProgramDesc: protoProgram, inParamPath: paraPath, inScope: scope, inFeedKey: feedKey, inFetchKey: fetchKey)
+
return program
} catch _ {
throw PaddleMobileError.loaderError(message: "protobuf decoder error")
diff --git a/metal/paddle-mobile/paddle-mobile/Operators/OpParam.swift b/metal/paddle-mobile/paddle-mobile/Operators/Base/OpParam.swift
similarity index 100%
rename from metal/paddle-mobile/paddle-mobile/Operators/OpParam.swift
rename to metal/paddle-mobile/paddle-mobile/Operators/Base/OpParam.swift
diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Operator.swift b/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift
similarity index 95%
rename from metal/paddle-mobile/paddle-mobile/Operators/Operator.swift
rename to metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift
index 05434c2b956423e32d539dfba417f245c3037ab6..7255c73773c35484fa392373ba5d45a654d06b45 100644
--- a/metal/paddle-mobile/paddle-mobile/Operators/Operator.swift
+++ b/metal/paddle-mobile/paddle-mobile/Operators/Base/Operator.swift
@@ -14,7 +14,6 @@
import Foundation
-
protocol Runable {
func run()
func runImpl()
@@ -27,7 +26,7 @@ extension Runable where Self: OperatorProtocol{
}
protocol Creator where Self: OperatorProtocol{
- associatedtype OpType: OperatorProtocol
+ associatedtype OpType: OperatorProtocol & Runable & InferShaperable
static func creat(opDesc: OpDesc, inScope: Scope) throws -> OpType
}
@@ -41,6 +40,11 @@ extension Creator where Self: OperatorProtocol {
}
}
+protocol InferShaperable {
+ func inferShape()
+}
+
+
protocol OperatorProtocol {
associatedtype ParamType: OpParam
var type: String { get }
diff --git a/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift
index 706e8002607112570beb07835ca1d28b7fc3b959..7a5733a70a854d972e477a0c5f57f3da517dcd22 100644
--- a/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift
+++ b/metal/paddle-mobile/paddle-mobile/Operators/BatchNormOp.swift
@@ -18,8 +18,8 @@ struct BatchNormParam
func runImpl() {
print("this is BatchNormOp")
diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift
index 9d3d0d8fed518c0954d452a3a162727c70a6efb8..69fa76fd3a62ac875347f6c19ada6a88a579f809 100644
--- a/metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift
+++ b/metal/paddle-mobile/paddle-mobile/Operators/ConvOp.swift
@@ -39,7 +39,28 @@ struct ConvParam
- let out: Texture
+ let output: Texture
let axis: Int
}
-class ElementwiseAddOp
func runImpl() {
print("this is ElementwiseAddOp")
diff --git a/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift
index c82aa10d773ac2673c563b5f1a798ae7f8c67999..8c4bee3aa6ca6b6c80b4f2d25d6d0c7415718c86 100644
--- a/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift
+++ b/metal/paddle-mobile/paddle-mobile/Operators/FeedOp.swift
@@ -15,15 +15,28 @@
import Foundation
struct FeedParam
+
+ func inferShape() {
+ para.output.dim = para.input.dim
+ }
+
func runImpl() {
print("feed op")
}
diff --git a/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift
index 5e17bfefe2abd083556e36419c7ede4be88c6c3f..bd333bff0ae39bfcbe57b0d55eb91bcd786f5dad 100644
--- a/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift
+++ b/metal/paddle-mobile/paddle-mobile/Operators/FetchOp.swift
@@ -15,17 +15,29 @@
import Foundation
struct FetchParam
func runImpl() {
- print("feed op")
+ print("fetch op")
}
}
diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift
index 59fff1c18633f8423c28aceb4f9bbff02aac4c4b..d641d13ecbcf1f4e70a0ff056eab5c68b4fc30ae 100644
--- a/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift
+++ b/metal/paddle-mobile/paddle-mobile/Operators/ReluOp.swift
@@ -18,17 +18,22 @@ struct ReluParam
func runImpl() {
print("this is ReluOp")
diff --git a/metal/paddle-mobile/paddle-mobile/Program/Program.swift b/metal/paddle-mobile/paddle-mobile/Program/Program.swift
index a346af8304688ba766928c6c6132dc7b840e683d..5a93e79338b5d694178313064cf49cfbb0d969db 100644
--- a/metal/paddle-mobile/paddle-mobile/Program/Program.swift
+++ b/metal/paddle-mobile/paddle-mobile/Program/Program.swift
@@ -16,11 +16,15 @@ import Foundation
public struct Program {
let paramPath: String
+ let feedKey: String
+ let fetchKey: String
let programDesc: ProgramDesc
let scope: Scope
- init(protoProgramDesc: PaddleMobile_Framework_Proto_ProgramDesc, inParamPath: String, inScope: Scope) {
+ init(protoProgramDesc: PaddleMobile_Framework_Proto_ProgramDesc, inParamPath: String, inScope: Scope, inFeedKey: String, inFetchKey: String) {
programDesc = ProgramDesc.init(protoProgram: protoProgramDesc)
paramPath = inParamPath
scope = inScope
+ feedKey = inFeedKey
+ fetchKey = inFetchKey
}
}
diff --git a/metal/paddle-mobile/paddle-mobile/Program/Scope.swift b/metal/paddle-mobile/paddle-mobile/Program/Scope.swift
index 0f34ed20ddf7feed4e22c8ce330a3343a41edc93..7ad95fa5357941c5c24163b37a7eaf6961e4be7e 100644
--- a/metal/paddle-mobile/paddle-mobile/Program/Scope.swift
+++ b/metal/paddle-mobile/paddle-mobile/Program/Scope.swift
@@ -17,7 +17,13 @@ import Foundation
class Scope {
var vars: [String : Variant] = [:]
subscript(key: String) -> Variant?{
- return vars[key]
+ get {
+ return vars[key]
+ }
+ set {
+ vars[key] = newValue
+ }
+
}
}
diff --git a/metal/paddle-mobile/paddle-mobile/framework/Dim.swift b/metal/paddle-mobile/paddle-mobile/framework/Dim.swift
index fa96f23b74832b0244535ddf28b221d5fdbb9bc7..eb5c57a5611eafbb4fc50742f21f22e5b7a57903 100644
--- a/metal/paddle-mobile/paddle-mobile/framework/Dim.swift
+++ b/metal/paddle-mobile/paddle-mobile/framework/Dim.swift
@@ -14,7 +14,7 @@
import Foundation
-struct Dim {
+public struct Dim {
init(inDim: [Int]) {
dims = inDim
}
diff --git a/metal/paddle-mobile/paddle-mobile/framework/Tensor.swift b/metal/paddle-mobile/paddle-mobile/framework/Tensor.swift
index 7d00e9b8687674cf955910f3d208fb3140a7032f..8183f80523bf903373b4b99e1f4460ce76367477 100644
--- a/metal/paddle-mobile/paddle-mobile/framework/Tensor.swift
+++ b/metal/paddle-mobile/paddle-mobile/framework/Tensor.swift
@@ -14,10 +14,28 @@
import Foundation
-class Tensor
init(inDimArray: [Int], inData: ParamData ) {
paraData = inData
@@ -25,9 +43,6 @@ class Tensor ) {
paraData = inData
}
- func numel() -> Int {
- return dim.numel()
- }
func dataLayout() -> DataLayout {
return paraData.layout
diff --git a/metal/paddle-mobile/paddle-mobile/framework/Texture.swift b/metal/paddle-mobile/paddle-mobile/framework/Texture.swift
index 5acd14fbb0a8e8a6bed24b73ab8315b11a87f817..e62fb1d8d7ed10d2cf7ca94f6a3df6496090f07f 100644
--- a/metal/paddle-mobile/paddle-mobile/framework/Texture.swift
+++ b/metal/paddle-mobile/paddle-mobile/framework/Texture.swift
@@ -12,8 +12,25 @@
See the License for the specific language governing permissions and
limitations under the License. */
+import Metal
import Foundation
-class Texture {
+public class Texture: Tensorial {
+ var dim: Dim
+// let texture: MTLTexture
+ func dataLayout() -> DataLayout {
+ return .NHWC
+ }
+
+ public init(inTexture: MTLTexture, inDim: Dim) {
+// texture = inTexture
+ dim = inDim
+ }
+
+ public init() {
+ dim = Dim.init(inDim: [])
+
+// fatalError()
+ }
}