Commit 5a5927ac authored by Ruilong Liu, committed by GitHub

Merge pull request #467 from codeWorm2015/metal

complete Skeleton Code
@@ -30,6 +30,7 @@
 FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB520E11CC20081E9F8 /* OpDesc.swift */; };
 FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB620E11CC20081E9F8 /* Attribute.swift */; };
 FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC039BB720E11CC20081E9F8 /* BlockDesc.swift */; };
+FC82735920E3C04200BE430A /* OpCreator.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC82735820E3C04200BE430A /* OpCreator.swift */; };
 FC9D037920E229E4000F735A /* OpParam.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037820E229E4000F735A /* OpParam.swift */; };
 FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037F20E22FBB000F735A /* FeedOp.swift */; };
 FC9D038220E2312E000F735A /* FetchOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038120E2312E000F735A /* FetchOp.swift */; };
@@ -64,6 +65,7 @@
 FC039BB520E11CC20081E9F8 /* OpDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OpDesc.swift; sourceTree = "<group>"; };
 FC039BB620E11CC20081E9F8 /* Attribute.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Attribute.swift; sourceTree = "<group>"; };
 FC039BB720E11CC20081E9F8 /* BlockDesc.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BlockDesc.swift; sourceTree = "<group>"; };
+FC82735820E3C04200BE430A /* OpCreator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpCreator.swift; sourceTree = "<group>"; };
 FC9D037820E229E4000F735A /* OpParam.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpParam.swift; sourceTree = "<group>"; };
 FC9D037F20E22FBB000F735A /* FeedOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FeedOp.swift; sourceTree = "<group>"; };
 FC9D038120E2312E000F735A /* FetchOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FetchOp.swift; sourceTree = "<group>"; };
@@ -195,6 +197,7 @@
 children = (
 FC9D037820E229E4000F735A /* OpParam.swift */,
 FC039BA620E11CBC0081E9F8 /* Operator.swift */,
+FC82735820E3C04200BE430A /* OpCreator.swift */,
 );
 path = Base;
 sourceTree = "<group>";
@@ -316,6 +319,7 @@
 FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */,
 FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */,
 FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */,
+FC82735920E3C04200BE430A /* OpCreator.swift in Sources */,
 FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */,
 FC9D038220E2312E000F735A /* FetchOp.swift in Sources */,
 FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */,
......
@@ -24,69 +24,19 @@ public protocol PrecisionType {
extension Float32: PrecisionType {
}
-enum DataLayout {
+public enum DataLayout {
    case NCHW
    case NHWC
}
-protocol Variant {
+protocol Variant: CustomStringConvertible, CustomDebugStringConvertible {
}
extension Tensor: Variant {
}
extension Texture: Variant {
}
+extension ResultHolder: Variant {
-let gFetchType = "fetch"
-let gFeedType = "feed"
-let gConvType = "conv2d"
-let gBatchNormType = "batch_norm"
-let gReluType = "relu"
-let gElementwiseAdd = "elementwise_add"
-fileprivate var singletons : [String : Any] = [:]
-class OpCreator<P: PrecisionType> {
-    static var shared : OpCreator<P> {
-        let key = String(describing: P.self)
-        if let singleton = singletons[key] {
-            return singleton as! OpCreator<P>
-        } else {
-            let newSingleton = OpCreator<P>()
-            singletons[key] = newSingleton
-            return newSingleton
-        }
-    }
-    func creat(opDesc: OpDesc, scope: Scope) throws -> Runable & InferShaperable {
-        guard let opCreator = opCreators[opDesc.type] else {
-            throw PaddleMobileError.opError(message: "there is no " + opDesc.type + " yet")
-        }
-        do {
-            return try opCreator(opDesc, scope)
-        } catch let error {
-            throw error
-        }
-    }
-    let opCreators: [String : (OpDesc, Scope) throws -> Runable & InferShaperable] =
-        [gConvType : ConvOp<P>.creat,
-         gBatchNormType : BatchNormOp<P>.creat,
-         gReluType : ReluOp<P>.creat,
-         gElementwiseAdd : ElementwiseAddOp<P>.creat,
-         gFeedType : FeedOp<P>.creat,
-         gFetchType : FetchOp<P>.creat]
-    private init(){}
}
-let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
-               gBatchNormType : (inputs: ["X"], outputs: ["Y"]),
-               gReluType : (inputs: ["X"], outputs: ["Out"]),
-               gElementwiseAdd : (inputs: ["X", "Y"], outputs: ["Out"]),
-               gFeedType : (inputs: ["X"], outputs: ["Out"]),
-               gFetchType : (inputs: ["X"], outputs: ["Out"])]
@@ -14,9 +14,41 @@
import Foundation
+public class ResultHolder<P: PrecisionType> {
+    public let dim: [Int]
+    public let resultArr: [P]
+    public init(inDim: [Int], inResult: [P]) {
+        dim = inDim
+        resultArr = inResult
+    }
+}
+extension ResultHolder: CustomDebugStringConvertible, CustomStringConvertible {
+    public var debugDescription: String {
+        var str = ""
+        str += "Dim: \(dim) \n value:[ "
+        if resultArr.count < 20 {
+            for d in resultArr {
+                str += " \(d) "
+            }
+        } else {
+            for d in stride(from: 0, to: resultArr.count, by: resultArr.count/20) {
+                str += " \(resultArr[d]) "
+            }
+        }
+        str += " ]"
+        return str
+    }
+    public var description: String {
+        return debugDescription
+    }
+}
public class Executor<P: PrecisionType> {
    var ops: [Runable & InferShaperable] = []
    let program: Program
    public init(inProgram: Program) throws {
        program = inProgram
        for block in inProgram.programDesc.blocks {
@@ -32,13 +64,13 @@ public class Executor<P: PrecisionType> {
        }
    }
-    public func predict(input: Texture) throws -> Texture {
+    public func predict(input: Texture) throws -> ResultHolder<P> {
        program.scope[program.feedKey] = input
        for op in ops {
            op.run()
        }
        let outputVar = program.scope[program.fetchKey]
-        guard let output = outputVar as? Texture else {
+        guard let output = outputVar as? ResultHolder<P> else {
            throw PaddleMobileError.netError(message: "output var type error")
        }
        return output
......
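Note on the new return type: predict now hands back a ResultHolder<P> instead of a Texture, and ResultHolder prints through its debugDescription. A tiny sketch of what a caller sees (the dims and values are made up for illustration; Float32 is the only PrecisionType conformance visible in this diff):

// Sketch only: illustrates the ResultHolder<P> type that predict(input:) now returns.
let holder = ResultHolder<Float32>.init(inDim: [1, 8], inResult: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
print(holder)   // "Dim: [1, 8] ... value:[ 0.1 0.2 ... ]" -- all 8 values, since count < 20
// A call site would look like: let result = try executor.predict(input: inputTexture)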
@@ -15,72 +15,6 @@
import Foundation
import SwiftProtobuf
-class ParamData<P: PrecisionType> {
-    let size: Int
-    var dim: Dim
-    private(set) var layout: DataLayout
-    var pointer: UnsafeMutablePointer<P>
-    init(inDim: Dim, inLayout: DataLayout = .NCHW) {
-        dim = inDim
-        size = inDim.numel() * MemoryLayout<P>.size
-        pointer = UnsafeMutablePointer<P>.allocate(capacity: size)
-        layout = inLayout
-    }
-    func convert(to: DataLayout) {
-        guard to != layout else {
-            return
-        }
-        guard dim.cout() == 4 else {
-            return
-        }
-        guard layout == .NCHW && to == .NHWC else {
-            // other not support
-            return
-        }
-        let newPointer = UnsafeMutablePointer<P>.allocate(capacity: size)
-        if layout == .NCHW {
-            NCHW2NHWC(newPtr: newPointer)
-        }
-        pointer.deinitialize(count: size)
-        pointer.deallocate()
-        pointer = newPointer
-        layout = to
-    }
-    func NCHW2NHWC(newPtr: UnsafeMutablePointer<P>) {
-        let N = dim[0]
-        let C = dim[1]
-        let H = dim[2]
-        let W = dim[3]
-        let HXW = H * W
-        let CXHXW = C * H * W
-        var index: Int = 0
-        for n in 0..<N {
-            for h in 0..<H {
-                for w in 0..<W {
-                    for c in 0..<C {
-                        newPtr[index] = pointer[n * CXHXW + c * HXW + h * w + w]
-                        index += 1
-                    }
-                }
-            }
-        }
-        dim.swapeDimAt(index1: 1, index2: 3)
-    }
-    deinit {
-        pointer.deinitialize(count: size)
-        pointer.deallocate()
-    }
-}
public class Loader<P: PrecisionType> {
    class ParaLoader {
@@ -101,7 +35,7 @@ public class Loader<P: PrecisionType> {
            nowIndex = 0
        }
-        func read(data: ParamData<P>) throws {
+        func read(tensor: Tensor<P>) throws {
            guard nowIndex <= fileSize else {
                throw PaddleMobileError.loaderError(message: "out of the file range")
            }
@@ -135,8 +69,8 @@ public class Loader<P: PrecisionType> {
             The precision is not inferred from the Data Type here; it is fixed by the externally supplied generic parameter.
             */
-            let bytesRead = fread(data.pointer, 1, data.size, file)
-            guard bytesRead == data.size else {
+            let bytesRead = fread(tensor.data.pointer, 1, tensor.data.size, file)
+            guard bytesRead == tensor.data.size else {
                throw PaddleMobileError.loaderError(message: "param read size error")
            }
            nowIndex += bytesRead
@@ -166,63 +100,66 @@ public class Loader<P: PrecisionType> {
            throw PaddleMobileError.loaderError(message: "count of blocks must greater than 0")
        }
+        // to get feed key and fetch key
+        let block = programDesc.blocks[0]
+        guard let firstOp = block.ops.first, let lastOp = block.ops.last else {
+            throw PaddleMobileError.loaderError(message: "at least two operator")
+        }
+        guard firstOp.type == gFeedType, lastOp.type == gFetchType else {
+            throw PaddleMobileError.loaderError(message: "the first op is not feed or the last op is not fetch")
+        }
+        guard let inputKey = opInfos[gFeedType]?.inputs.first, let outKey = opInfos[gFetchType]?.outputs.first else {
+            throw PaddleMobileError.loaderError(message: "the feed input key or fetch output key not found")
+        }
+        guard let feedKey = firstOp.inputs[inputKey]?.first, let fetchKey = lastOp.outputs[outKey]?.first else {
+            throw PaddleMobileError.loaderError(message: "feed key or fetch key not found")
+        }
+        // to load memory
        for block in programDesc.blocks {
            for varDesc in block.vars {
+                print(varDesc.name + "\(varDesc.persistable)")
                if (varDesc.type == .LodTensor) {
+                    guard let tensorDesc = varDesc.tensorDesc else {
+                        throw PaddleMobileError.loaderError(message: "get tensor desc failed")
+                    }
+                    guard (try? tensorDesc.dataType.dataTypeSize()) == MemoryLayout<P>.size else {
+                        throw PaddleMobileError.memoryError(message: "PrecisionType not support")
+                    }
                    if (varDesc.persistable
                        && varDesc.type != .FeedMiniBatch
                        && varDesc.type != .FetchList) {
-                        guard let tensorDesc = varDesc.tensorDesc else {
-                            throw PaddleMobileError.loaderError(message: "get tensor desc failed")
-                        }
-                        guard (try? tensorDesc.dataType.dataTypeSize()) == MemoryLayout<P>.size else {
-                            throw PaddleMobileError.memoryError(message: "PrecisionType not support")
-                        }
                        let dimArr = tensorDesc.dims
                        guard dimArr.count > 0 else {
                            throw PaddleMobileError.loaderError(message: "tensor desc dim size error")
                        }
                        let dim = Dim.init(inDim: dimArr)
-                        let paraData = ParamData<P>.init(inDim: dim)
+                        let tensor = Tensor<P>.init(inDim: dim, inLayout: tensorDesc.dataLayout)
                        do {
-                            try paraLoader.read(data: paraData)
+                            try paraLoader.read(tensor: tensor)
                        } catch let error {
                            throw error
                        }
-                        paraData.convert(to: .NHWC)
-                        let tensor = Tensor<P>.init(inData: paraData)
+                        tensor.convert(to: .NHWC)
                        scope[varDesc.name] = tensor
                    } else {
-                        scope[varDesc.name] = Texture.init()
+                        let dim = Dim.init(inDim: tensorDesc.NHWCDim)
+                        scope[varDesc.name] = Texture.init(inDim: dim, inLayout: .NHWC)
                    }
                } else {
-                    scope[varDesc.name] = Texture.init()
+                    if varDesc.name == fetchKey {
+                        scope[varDesc.name] = ResultHolder<P>.init(inDim: [], inResult: [])
+                    } else if varDesc.name == feedKey {
+                        scope[varDesc.name] = Texture.init()
+                    }
                }
            }
        }
-        let block = programDesc.blocks[0]
-        guard let firstOp = block.ops.first, let lastOp = block.ops.last else {
-            throw PaddleMobileError.loaderError(message: "at least two operator")
-        }
-        guard firstOp.type == gFeedType, lastOp.type == gFetchType else {
-            throw PaddleMobileError.loaderError(message: "the first op is not feed or the last op is not fetch")
-        }
-        guard let inputKey = opInfos[gFeedType]?.inputs.first, let outKey = opInfos[gFetchType]?.outputs.first else {
-            throw PaddleMobileError.loaderError(message: "the feed input key or fetch output key not found")
-        }
-        guard let feedKey = firstOp.inputs[inputKey]?.first, let fetchKey = lastOp.outputs[outKey]?.first else {
-            throw PaddleMobileError.loaderError(message: "feed key or fetch key not found")
-        }
        let program = Program.init(protoProgramDesc: protoProgram, inParamPath: paraPath, inScope: scope, inFeedKey: feedKey, inFetchKey: fetchKey)
        return program
......
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+import Foundation
+fileprivate var singletons : [String : Any] = [:]
+class OpCreator<P: PrecisionType> {
+    static var shared : OpCreator<P> {
+        let key = String(describing: P.self)
+        if let singleton = singletons[key] {
+            return singleton as! OpCreator<P>
+        } else {
+            let newSingleton = OpCreator<P>()
+            singletons[key] = newSingleton
+            return newSingleton
+        }
+    }
+    func creat(opDesc: OpDesc, scope: Scope) throws -> Runable & InferShaperable {
+        guard let opCreator = opCreators[opDesc.type] else {
+            throw PaddleMobileError.opError(message: "there is no " + opDesc.type + " yet")
+        }
+        do {
+            return try opCreator(opDesc, scope)
+        } catch let error {
+            throw error
+        }
+    }
+    let opCreators: [String : (OpDesc, Scope) throws -> Runable & InferShaperable] =
+        [gConvType : ConvOp<P>.creat,
+         gBatchNormType : BatchNormOp<P>.creat,
+         gReluType : ReluOp<P>.creat,
+         gElementwiseAdd : ElementwiseAddOp<P>.creat,
+         gFeedType : FeedOp<P>.creat,
+         gFetchType : FetchOp<P>.creat]
+    private init(){}
+}
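The shared accessor above keys the fileprivate singletons dictionary by String(describing: P.self), so each precision gets its own OpCreator instance. A minimal usage sketch (opDesc and scope are assumed to come from a parsed program, as in Loader):

// Sketch only: opDesc/scope would come from ProgramDesc blocks, as in Loader.load.
let creator = OpCreator<Float32>.shared                   // cached under the key "Float32"
let op = try creator.creat(opDesc: opDesc, scope: scope)  // throws opError for unknown op types
op.inferShape()
op.run()   // run() also prints "<type>: <outputDesc>" via the Runable extension below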
@@ -22,6 +22,10 @@ import Foundation
 */
protocol OpParam {
+    associatedtype OutputType: Variant
+    var output: OutputType { get }
+    func outputDesc() -> String
    associatedtype ParamPrecisionType: PrecisionType
    init(opDesc: OpDesc, scope: Scope) throws
    static func getFirstTensor<VarType: Variant>(key: String, map: [String : [String]], from: Scope) throws -> VarType
@@ -40,6 +44,10 @@ protocol OpParam {
}
extension OpParam {
+    func outputDesc() -> String {
+        return output.debugDescription
+    }
    static func getFirstTensor<VarType: Variant>(key: String, map: [String : [String]], from: Scope) throws -> VarType {
        guard let mapKeys = map[key], mapKeys.count > 0 else {
            throw PaddleMobileError.paramError(message: key + " not found in \(map) or maped values is empty")
......
@@ -22,6 +22,7 @@ protocol Runable {
extension Runable where Self: OperatorProtocol{
    func run() {
        runImpl()
+        print(type + ": " + para.outputDesc())
    }
}
@@ -44,7 +45,6 @@ protocol InferShaperable {
    func inferShape()
}
protocol OperatorProtocol {
    associatedtype ParamType: OpParam
    var type: String { get }
@@ -66,6 +66,7 @@ extension OperatorProtocol {
    }
}
class Operator <ParameterType: OpParam>: OperatorProtocol{
    typealias ParamType = ParameterType
    let type: String
@@ -87,3 +88,19 @@ class Operator <ParameterType: OpParam>: OperatorProtocol{
        }
    }
}
+// op infos
+let gFetchType = "fetch"
+let gFeedType = "feed"
+let gConvType = "conv2d"
+let gBatchNormType = "batch_norm"
+let gReluType = "relu"
+let gElementwiseAdd = "elementwise_add"
+let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
+               gBatchNormType : (inputs: ["X"], outputs: ["Y"]),
+               gReluType : (inputs: ["X"], outputs: ["Out"]),
+               gElementwiseAdd : (inputs: ["X", "Y"], outputs: ["Out"]),
+               gFeedType : (inputs: ["X"], outputs: ["Out"]),
+               gFetchType : (inputs: ["X"], outputs: ["Out"])]
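For reference, opInfos is the table Loader consults to find the feed and fetch variable names; a small illustration of the lookups involved (the last two lines paraphrase Loader.load above):

let feedInputKey = opInfos[gFeedType]!.inputs.first!      // "X"
let fetchOutputKey = opInfos[gFetchType]!.outputs.first!  // "Out"
// Loader then resolves the actual variable names from the program's first and last ops:
// firstOp.inputs["X"]?.first  -> feedKey,   lastOp.outputs["Out"]?.first -> fetchKey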
@@ -41,24 +41,25 @@ struct ConvParam<P: PrecisionType>: OpParam {
class ConvOp<P: PrecisionType>: Operator<ConvParam<P>>, Runable, Creator, InferShaperable {
    func inferShape() {
-        // let inDims = para.input.dim
-        // let filterDim = para.filter.dim
-        // let strides = para.stride
-        // let paddings = para.paddings
-        // let dilations = para.dilations
-        //
-        // var outDim = [inDims[0]]
-        // for i in 0..<strides.count {
-        //     let dilation: Int = Int(dilations[i])
-        //     let filterSize: Int = filterDim[i + 1]
-        //     let inputSize: Int = inDims[i + 1]
-        //     let padding: Int = Int(paddings[i])
-        //     let stride: Int = Int(strides[i])
-        //     let dKernel = dilation * (filterSize - 1) + 1
-        //     let outputSize = (inputSize + 2 * padding - dKernel) / stride + 1
-        //     outDim.append(outputSize)
-        // }
-        // outDim.append(filterDim[0])
+        let inDims = para.input.dim
+        let filterDim = para.filter.dim
+        let strides = para.stride
+        let paddings = para.paddings
+        let dilations = para.dilations
+        var outDim = [inDims[0]]
+        for i in 0..<strides.count {
+            let dilation: Int = Int(dilations[i])
+            let filterSize: Int = filterDim[i + 1]
+            let inputSize: Int = inDims[i + 1]
+            let padding: Int = Int(paddings[i])
+            let stride: Int = Int(strides[i])
+            let dKernel = dilation * (filterSize - 1) + 1
+            let outputSize = (inputSize + 2 * padding - dKernel) / stride + 1
+            outDim.append(outputSize)
+        }
+        outDim.append(filterDim[0])
+        para.output.dim = Dim.init(inDim: outDim)
    }
    typealias OpType = ConvOp<P>
......
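The now-active inferShape uses the standard dilated-convolution output size, outputSize = (inputSize + 2*padding - (dilation*(filterSize-1)+1)) / stride + 1, appended per spatial axis and followed by the output channel count filterDim[0]. A quick worked example with made-up numbers:

// Illustrative values only: 3x3 kernel, stride 1, padding 1, dilation 1, 224-wide input.
let inputSize = 224, filterSize = 3, padding = 1, stride = 1, dilation = 1
let dKernel = dilation * (filterSize - 1) + 1                        // 3
let outputSize = (inputSize + 2 * padding - dKernel) / stride + 1    // (224 + 2 - 3) / 1 + 1 = 224
// So a [1, 3, 224, 224] input with a [16, 3, 3, 3] filter infers an output dim of [1, 224, 224, 16].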
@@ -34,7 +34,6 @@ class FeedOp<P: PrecisionType>: Operator<FeedParam<P>>, Runable, Creator, InferS
    typealias OpType = FeedOp<P>
    func inferShape() {
-        para.output.dim = para.input.dim
    }
    func runImpl() {
......
@@ -15,9 +15,9 @@
import Foundation
struct FetchParam<P: PrecisionType>: OpParam{
-    var input: Texture
-    var output: Texture
+    let output: ResultHolder<P>
+    let input: Texture
    init(opDesc: OpDesc, scope: Scope) throws {
        do {
            input = try FetchParam.inputX(inputs: opDesc.inputs, from: scope)
@@ -32,7 +32,7 @@ struct FetchParam<P: PrecisionType>: OpParam{
class FetchOp<P: PrecisionType>: Operator<FetchParam<P>>, Runable, Creator, InferShaperable{
    func inferShape() {
-        para.output.dim = para.input.dim
+        print(para.input.dim)
    }
    typealias OpType = FetchOp<P>
......
@@ -35,7 +35,4 @@ struct BlockDesc {
        self.ops = ops
    }
}
@@ -17,8 +17,43 @@ import Foundation
struct TensorDesc {
    let dims: [Int]
    let dataType: VarTypeType
+    let dataLayout: DataLayout = .NCHW
+    var NCHWDim: [Int] {
+        get {
+            if dims.count != 4 {
+                return dims
+            }
+            if dataLayout == .NCHW {
+                return dims
+            } else if dataLayout == .NHWC {
+                var resultDims = dims
+                resultDims.swapAt(1, 3)
+                return resultDims
+            } else {
+                fatalError(" not support other layout")
+            }
+        }
+    }
+    var NHWCDim: [Int] {
+        get {
+            if dims.count != 4 {
+                return dims
+            }
+            if dataLayout == .NHWC {
+                return dims
+            } else if dataLayout == .NCHW {
+                var resultDims = dims
+                resultDims.swapAt(1, 3)
+                return resultDims
+            } else {
+                fatalError(" not support other layout")
+            }
+        }
+    }
    init(protoTensorDesc: PaddleMobile_Framework_Proto_VarType.TensorDesc) {
-        dims = protoTensorDesc.dims.map{ Int($0)}
+        dims = protoTensorDesc.dims.map{ Int($0) > 0 ? Int($0) : 1 }
        dataType = VarTypeType.init(rawValue: protoTensorDesc.dataType.rawValue) ?? .ErrorType
    }
......
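The two computed properties above just swap axes 1 and 3 when the stored layout differs from the requested one. For example (values are only an illustration): with dims [1, 3, 224, 224] and the default dataLayout of .NCHW, NCHWDim returns the dims unchanged while NHWCDim returns [1, 224, 224, 3]:

var dims = [1, 3, 224, 224]
dims.swapAt(1, 3)      // the same operation NHWCDim performs for an NCHW-described tensor
print(dims)            // [1, 224, 224, 3]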
@@ -44,3 +44,9 @@ public struct Dim {
        fatalError()
    }
}
+extension Dim: CustomStringConvertible {
+    public var description: String {
+        return "\(dims)"
+    }
+}
@@ -14,38 +14,124 @@
import Foundation
-protocol Tensorial {
+protocol Tensorial: CustomStringConvertible, CustomDebugStringConvertible{
    var dim: Dim { get set }
    func numel() -> Int
-    func dataLayout() -> DataLayout
+    var layout: DataLayout { get }
+    init(inDim: Dim, inLayout: DataLayout)
}
extension Tensorial {
    func numel() -> Int {
        return dim.numel()
    }
}
-class Tensor <P: PrecisionType>: Tensorial {
-    var dim: Dim {
-        get {
-            return paraData.dim
-        }
-        set {
-            paraData.dim = newValue
-        }
-    }
-    let paraData: ParamData<P>
-    init(inDimArray: [Int], inData: ParamData<P>) {
-        paraData = inData
-    }
-    init(inData: ParamData<P>) {
-        paraData = inData
-    }
-    func dataLayout() -> DataLayout {
-        return paraData.layout
+class Tensor<P: PrecisionType>: Tensorial {
+    var data: Data
+    var dim: Dim
+    private(set) var layout: DataLayout
+    class Data {
+        init(inSize: Int, inPointer: UnsafeMutablePointer<P>) {
+            size = inSize
+            pointer = inPointer
+        }
+        let size: Int
+        var pointer: UnsafeMutablePointer<P>
+        subscript(index: Int) -> P {
+            get {
+                return pointer[index]
+            }
+            set {
+                pointer[index] = newValue
+            }
+        }
+        func release() {
+            pointer.deinitialize(count: size)
+            pointer.deallocate()
+        }
+        deinit {
+            release()
+        }
+    }
+    required init(inDim: Dim, inLayout: DataLayout = .NCHW) {
+        dim = inDim
+        let size = inDim.numel() * MemoryLayout<P>.size
+        let pointer = UnsafeMutablePointer<P>.allocate(capacity: size)
+        data = Data.init(inSize: size, inPointer: pointer)
+        layout = inLayout
+    }
+    func convert(to: DataLayout) {
+        guard to != layout else {
+            return
+        }
+        guard dim.cout() == 4 else {
+            return
+        }
+        guard layout == .NCHW && to == .NHWC else {
+            // other not support
+            return
+        }
+        let newPointer = UnsafeMutablePointer<P>.allocate(capacity: data.size)
+        if layout == .NCHW {
+            NCHW2NHWC(newPtr: newPointer)
+        }
+        data.release()
+        data.pointer = newPointer
+        layout = to
+    }
+    func NCHW2NHWC(newPtr: UnsafeMutablePointer<P>) {
+        let N = dim[0]
+        let C = dim[1]
+        let H = dim[2]
+        let W = dim[3]
+        let HXW = H * W
+        let CXHXW = C * H * W
+        var index: Int = 0
+        for n in 0..<N {
+            for h in 0..<H {
+                for w in 0..<W {
+                    for c in 0..<C {
+                        newPtr[index] = data.pointer[n * CXHXW + c * HXW + h * w + w]
+                        index += 1
+                    }
+                }
+            }
+        }
+        dim.swapeDimAt(index1: 1, index2: 3)
+    }
+}
+extension Tensor {
+    var debugDescription: String {
+        var str = ""
+        str += "Dim: \(dim) \n value:[ "
+        if data.size < 20 {
+            for d in 0..<data.size {
+                str += " \(data[d]) "
+            }
+        } else {
+            for d in stride(from: 0, to: data.size, by: data.size/20) {
+                str += " \(data[d]) "
+            }
+        }
+        str += " ]"
+        return str
+    }
+    var description: String {
+        return debugDescription
    }
}
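One detail worth flagging in NCHW2NHWC (present both in the removed ParamData version and in the new Tensor version): the source offset is written as n * CXHXW + c * HXW + h * w + w, whereas the conventional NCHW offset of element (n, c, h, w) is n*C*H*W + c*H*W + h*W + w, i.e. h * W + w. Below is a standalone sketch of the conversion with that indexing over a plain Float32 array; it is an illustration under that assumption, not the repository's code:

// NCHW -> NHWC copy for a flat buffer; dims = [N, C, H, W], destination laid out as [N, H, W, C].
func nchwToNHWC(_ src: [Float32], dims: [Int]) -> [Float32] {
    let (n, c, h, w) = (dims[0], dims[1], dims[2], dims[3])
    var dst = [Float32](repeating: 0, count: src.count)
    var index = 0
    for ni in 0..<n {
        for hi in 0..<h {
            for wi in 0..<w {
                for ci in 0..<c {
                    // offset of (ni, ci, hi, wi) in NCHW: note `hi * w + wi`
                    dst[index] = src[ni * c * h * w + ci * h * w + hi * w + wi]
                    index += 1
                }
            }
        }
    }
    return dst
}
print(nchwToNHWC([0, 1, 2, 3, 4, 5, 6, 7], dims: [1, 2, 2, 2]))   // [0.0, 4.0, 1.0, 5.0, 2.0, 6.0, 3.0, 7.0]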
@@ -17,20 +17,39 @@ import Foundation
public class Texture: Tensorial {
    var dim: Dim
-    // let texture: MTLTexture
-    func dataLayout() -> DataLayout {
-        return .NHWC
+    required public init(inDim: Dim, inLayout: DataLayout = .NHWC) {
+        dim = inDim
+        layout = inLayout
    }
+    private(set) var layout: DataLayout
+    // let texture: MTLTexture
    public init(inTexture: MTLTexture, inDim: Dim) {
        // texture = inTexture
        dim = inDim
+        layout = .NHWC
    }
-    public init() {
+    public init(inLayout: DataLayout = .NHWC) {
        dim = Dim.init(inDim: [])
-        // fatalError()
+        layout = inLayout
+    }
+}
+extension Texture {
+    public var description: String {
+        return debugDescription
+    }
+    public var debugDescription: String{
+        var str = ""
+        str += "Dim: \(dim) \n value:[ "
+        str += " ]"
+        return str
    }
}