提交 2694b34b 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #535 from codeWorm2015/metal

add conv add batch norm relu metal
......@@ -214,10 +214,8 @@
FC0E2DB420EDC03C009C1FAC /* conv2d_27.w_0 in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2CEA20EDC03B009C1FAC /* conv2d_27.w_0 */; };
FC0E2DB520EDC03C009C1FAC /* conv2d_33.w_0 in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2CEB20EDC03B009C1FAC /* conv2d_33.w_0 */; };
FC0E2DB620EDC03C009C1FAC /* depthwise_conv2d_7.w_0 in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2CEC20EDC03B009C1FAC /* depthwise_conv2d_7.w_0 */; };
FCEBC0FC20F227C60099DBAF /* mobilenet in Resources */ = {isa = PBXBuildFile; fileRef = FCEBC0F820F227C60099DBAF /* mobilenet */; };
FCEBC0FD20F227C60099DBAF /* params in Resources */ = {isa = PBXBuildFile; fileRef = FCEBC0F920F227C60099DBAF /* params */; };
FCEBC0FE20F227C60099DBAF /* model in Resources */ = {isa = PBXBuildFile; fileRef = FCEBC0FA20F227C60099DBAF /* model */; };
FCEBC0FF20F227C60099DBAF /* yolo in Resources */ = {isa = PBXBuildFile; fileRef = FCEBC0FB20F227C60099DBAF /* yolo */; };
FCD04E6320F3146B0007374F /* params in Resources */ = {isa = PBXBuildFile; fileRef = FCD04E6120F3146A0007374F /* params */; };
FCD04E6420F3146B0007374F /* model in Resources */ = {isa = PBXBuildFile; fileRef = FCD04E6220F3146A0007374F /* model */; };
FCEBEC2C20E1391F00C0B14D /* paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; };
FCEBEC2D20E1391F00C0B14D /* paddle_mobile.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
/* End PBXBuildFile section */
......@@ -448,10 +446,8 @@
FC0E2CEA20EDC03B009C1FAC /* conv2d_27.w_0 */ = {isa = PBXFileReference; lastKnownFileType = file; path = conv2d_27.w_0; sourceTree = "<group>"; };
FC0E2CEB20EDC03B009C1FAC /* conv2d_33.w_0 */ = {isa = PBXFileReference; lastKnownFileType = file; path = conv2d_33.w_0; sourceTree = "<group>"; };
FC0E2CEC20EDC03B009C1FAC /* depthwise_conv2d_7.w_0 */ = {isa = PBXFileReference; lastKnownFileType = file; path = depthwise_conv2d_7.w_0; sourceTree = "<group>"; };
FCEBC0F820F227C60099DBAF /* mobilenet */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = mobilenet; sourceTree = "<group>"; };
FCEBC0F920F227C60099DBAF /* params */ = {isa = PBXFileReference; lastKnownFileType = file; path = params; sourceTree = "<group>"; };
FCEBC0FA20F227C60099DBAF /* model */ = {isa = PBXFileReference; lastKnownFileType = file; path = model; sourceTree = "<group>"; };
FCEBC0FB20F227C60099DBAF /* yolo */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = yolo; sourceTree = "<group>"; };
FCD04E6120F3146A0007374F /* params */ = {isa = PBXFileReference; lastKnownFileType = file; path = params; sourceTree = "<group>"; };
FCD04E6220F3146A0007374F /* model */ = {isa = PBXFileReference; lastKnownFileType = file; path = model; sourceTree = "<group>"; };
FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */
......@@ -531,7 +527,7 @@
FC0E2C2020EDC03B009C1FAC /* models */ = {
isa = PBXGroup;
children = (
FCEBC0F720F227C60099DBAF /* yolo */,
FCD04E6020F3146A0007374F /* mobilenet */,
FC0E2C2420EDC03B009C1FAC /* mobilenetssd */,
);
name = models;
......@@ -745,15 +741,13 @@
path = mobilenetssd;
sourceTree = "<group>";
};
FCEBC0F720F227C60099DBAF /* yolo */ = {
FCD04E6020F3146A0007374F /* mobilenet */ = {
isa = PBXGroup;
children = (
FCEBC0F820F227C60099DBAF /* mobilenet */,
FCEBC0F920F227C60099DBAF /* params */,
FCEBC0FA20F227C60099DBAF /* model */,
FCEBC0FB20F227C60099DBAF /* yolo */,
FCD04E6120F3146A0007374F /* params */,
FCD04E6220F3146A0007374F /* model */,
);
path = yolo;
path = mobilenet;
sourceTree = "<group>";
};
/* End PBXGroup section */
......@@ -828,7 +822,6 @@
FC0E2D2920EDC03B009C1FAC /* batch_norm_2.b_0 in Resources */,
FC0E2DA920EDC03C009C1FAC /* conv2d_26.w_0 in Resources */,
FC0E2D0420EDC03B009C1FAC /* batch_norm_16.w_2 in Resources */,
FCEBC0FE20F227C60099DBAF /* model in Resources */,
FC0E2D0720EDC03B009C1FAC /* batch_norm_6.w_1 in Resources */,
FC0E2DB020EDC03C009C1FAC /* batch_norm_30.w_2 in Resources */,
FC0E2D9720EDC03C009C1FAC /* conv2d_25.w_0 in Resources */,
......@@ -844,11 +837,9 @@
FC0E2DA620EDC03C009C1FAC /* depthwise_conv2d_4.w_0 in Resources */,
FC0E2D6920EDC03C009C1FAC /* conv2d_6.w_0 in Resources */,
FC0E2D6520EDC03C009C1FAC /* conv2d_7.w_0 in Resources */,
FCEBC0FD20F227C60099DBAF /* params in Resources */,
FC0E2DAB20EDC03C009C1FAC /* batch_norm_19.w_2 in Resources */,
FC0E2D9920EDC03C009C1FAC /* conv2d_31.w_0 in Resources */,
FC0E2D3020EDC03B009C1FAC /* batch_norm_34.w_0 in Resources */,
FCEBC0FC20F227C60099DBAF /* mobilenet in Resources */,
FC0E2D1220EDC03B009C1FAC /* batch_norm_34.b_0 in Resources */,
FC0E2D4D20EDC03C009C1FAC /* batch_norm_7.b_0 in Resources */,
FC0E2D2520EDC03B009C1FAC /* batch_norm_21.w_1 in Resources */,
......@@ -857,6 +848,7 @@
FC0E2D8620EDC03C009C1FAC /* conv2d_23.w_0 in Resources */,
FC0E2CFE20EDC03B009C1FAC /* depthwise_conv2d_9.w_0 in Resources */,
FC0E2D4C20EDC03C009C1FAC /* batch_norm_8.w_2 in Resources */,
FCD04E6320F3146B0007374F /* params in Resources */,
FC0E2D5820EDC03C009C1FAC /* conv2d_5.w_0 in Resources */,
FC0E2D1620EDC03B009C1FAC /* batch_norm_3.w_1 in Resources */,
FC0E2DB120EDC03C009C1FAC /* batch_norm_24.w_2 in Resources */,
......@@ -949,7 +941,6 @@
FC0E2D0F20EDC03B009C1FAC /* batch_norm_5.w_0 in Resources */,
FC0E2D4520EDC03C009C1FAC /* batch_norm_9.w_2 in Resources */,
FC0E2D9020EDC03C009C1FAC /* batch_norm_23.w_2 in Resources */,
FCEBC0FF20F227C60099DBAF /* yolo in Resources */,
FC0E2D6720EDC03C009C1FAC /* conv2d_31.b_0 in Resources */,
FC0E2DA020EDC03C009C1FAC /* conv2d_18.w_0 in Resources */,
FC0E2D1C20EDC03B009C1FAC /* conv2d_13.w_0 in Resources */,
......@@ -972,6 +963,7 @@
FC0E2D0920EDC03B009C1FAC /* conv2d_14.w_0 in Resources */,
FC0E2CF720EDC03B009C1FAC /* batch_norm_28.w_2 in Resources */,
FC0E2D9520EDC03C009C1FAC /* depthwise_conv2d_5.w_0 in Resources */,
FCD04E6420F3146B0007374F /* model in Resources */,
FC0E2D4A20EDC03C009C1FAC /* conv2d_9.w_0 in Resources */,
FC0E2D4E20EDC03C009C1FAC /* batch_norm_19.w_1 in Resources */,
FC0E2D3620EDC03C009C1FAC /* batch_norm_18.w_0 in Resources */,
......
......@@ -45,6 +45,14 @@
FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037F20E22FBB000F735A /* FeedOp.swift */; };
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038120E2312E000F735A /* FetchOp.swift */; };
FC9D038420E23B01000F735A /* Texture.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038320E23B01000F735A /* Texture.swift */; };
FCD04E6620F314C50007374F /* PoolOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E6520F314C50007374F /* PoolOp.swift */; };
FCD04E6820F315020007374F /* PoolKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E6720F315020007374F /* PoolKernel.swift */; };
FCD04E6A20F319EC0007374F /* SoftmaxOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E6920F319EC0007374F /* SoftmaxOp.swift */; };
FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E6B20F31A280007374F /* SoftmaxKernel.swift */; };
FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E6D20F31B4B0007374F /* ReshapeOp.swift */; };
FCD04E7020F31B720007374F /* ReshapeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E6F20F31B720007374F /* ReshapeKernel.swift */; };
FCD04E7220F343420007374F /* ConvAddOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E7120F343420007374F /* ConvAddOp.swift */; };
FCD04E7420F3437E0007374F /* ConvAddKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCD04E7320F3437E0007374F /* ConvAddKernel.swift */; };
FCEBC0F420F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCEBC0F320F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift */; };
FCEBC0F620F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCEBC0F520F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift */; };
FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCF2D73720E64E70007AC5F5 /* Kernel.swift */; };
......@@ -93,6 +101,14 @@
FC9D037F20E22FBB000F735A /* FeedOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FeedOp.swift; sourceTree = "<group>"; };
FC9D038120E2312E000F735A /* FetchOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FetchOp.swift; sourceTree = "<group>"; };
FC9D038320E23B01000F735A /* Texture.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture.swift; sourceTree = "<group>"; };
FCD04E6520F314C50007374F /* PoolOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PoolOp.swift; sourceTree = "<group>"; };
FCD04E6720F315020007374F /* PoolKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PoolKernel.swift; sourceTree = "<group>"; };
FCD04E6920F319EC0007374F /* SoftmaxOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SoftmaxOp.swift; sourceTree = "<group>"; };
FCD04E6B20F31A280007374F /* SoftmaxKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SoftmaxKernel.swift; sourceTree = "<group>"; };
FCD04E6D20F31B4B0007374F /* ReshapeOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ReshapeOp.swift; sourceTree = "<group>"; };
FCD04E6F20F31B720007374F /* ReshapeKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ReshapeKernel.swift; sourceTree = "<group>"; };
FCD04E7120F343420007374F /* ConvAddOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddOp.swift; sourceTree = "<group>"; };
FCD04E7320F3437E0007374F /* ConvAddKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddKernel.swift; sourceTree = "<group>"; };
FCEBC0F320F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; name = ConvAddBatchNormReluOp.swift; path = "paddle-mobile/Operators/ConvAddBatchNormReluOp.swift"; sourceTree = SOURCE_ROOT; };
FCEBC0F520F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddBatchNormReluKernel.swift; sourceTree = "<group>"; };
FCF2D73720E64E70007AC5F5 /* Kernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; name = Kernel.swift; path = "paddle-mobile/Operators/Kernels/Kernel.swift"; sourceTree = SOURCE_ROOT; };
......@@ -193,6 +209,10 @@
FC039BA820E11CBC0081E9F8 /* ReluOp.swift */,
FC9D037F20E22FBB000F735A /* FeedOp.swift */,
FC9D038120E2312E000F735A /* FetchOp.swift */,
FCD04E6520F314C50007374F /* PoolOp.swift */,
FCD04E6920F319EC0007374F /* SoftmaxOp.swift */,
FCD04E6D20F31B4B0007374F /* ReshapeOp.swift */,
FCD04E7120F343420007374F /* ConvAddOp.swift */,
);
path = Operators;
sourceTree = "<group>";
......@@ -227,6 +247,10 @@
FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */,
FC4CB74820F0B954007C0C6D /* ConvKernel.metal */,
FCEBC0F520F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift */,
FCD04E6720F315020007374F /* PoolKernel.swift */,
FCD04E6B20F31A280007374F /* SoftmaxKernel.swift */,
FCD04E6F20F31B720007374F /* ReshapeKernel.swift */,
FCD04E7320F3437E0007374F /* ConvAddKernel.swift */,
);
path = Kernels;
sourceTree = "<group>";
......@@ -346,6 +370,8 @@
FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */,
FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */,
FC039B9B20E11CA00081E9F8 /* Executor.swift in Sources */,
FCD04E7020F31B720007374F /* ReshapeKernel.swift in Sources */,
FCD04E7220F343420007374F /* ConvAddOp.swift in Sources */,
FC039BBB20E11CC20081E9F8 /* ProgramDesc.swift in Sources */,
FC9D037920E229E4000F735A /* OpParam.swift in Sources */,
FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */,
......@@ -362,21 +388,27 @@
FC4CB74920F0B954007C0C6D /* ConvKernel.metal in Sources */,
FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */,
FC9D038420E23B01000F735A /* Texture.swift in Sources */,
FCD04E6E20F31B4B0007374F /* ReshapeOp.swift in Sources */,
FC039B9820E11C9A0081E9F8 /* Errors.swift in Sources */,
FC039BBF20E11CC20081E9F8 /* Attribute.swift in Sources */,
FCD04E7420F3437E0007374F /* ConvAddKernel.swift in Sources */,
FC039BB920E11CC20081E9F8 /* Scope.swift in Sources */,
FCD04E6620F314C50007374F /* PoolOp.swift in Sources */,
FC039BAC20E11CBC0081E9F8 /* BatchNormOp.swift in Sources */,
FC039BBC20E11CC20081E9F8 /* VarDesc.swift in Sources */,
FC0E2DBA20EE3B8D009C1FAC /* ReluKernel.swift in Sources */,
FC82735920E3C04200BE430A /* OpCreator.swift in Sources */,
FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */,
FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */,
FCD04E6A20F319EC0007374F /* SoftmaxOp.swift in Sources */,
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */,
FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */,
FC039BA220E11CB70081E9F8 /* Loader.swift in Sources */,
FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */,
FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */,
FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */,
FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */,
FCD04E6820F315020007374F /* PoolKernel.swift in Sources */,
FC039BAD20E11CBC0081E9F8 /* ReluOp.swift in Sources */,
FC039BBE20E11CC20081E9F8 /* OpDesc.swift in Sources */,
FC039B9720E11C9A0081E9F8 /* Extensions.swift in Sources */,
......
......@@ -72,6 +72,16 @@ extension Array: CIntIndex{
}
}
extension Array where Element: AnyObject{
mutating func remove(element: Element) {
if let index = index(where: { (node) -> Bool in
return unsafeBitCast(element, to: Int.self) == unsafeBitCast(node, to: Int.self)
}) {
remove(at: index)
}
}
}
//MARK: Array extension
extension Array where Element: Comparable{
......@@ -92,4 +102,10 @@ extension String{
}
}
func address<T: AnyObject>(o: T) -> String {
return String.init(format: "%018p", unsafeBitCast(o, to: Int.self))
}
......@@ -42,7 +42,6 @@ extension MTLDevice {
}
}
func pipeLine(funcName: String, inPaddleMobileLib: Bool = true) -> MTLComputePipelineState {
let useLib = inPaddleMobileLib ? paddleMobileLibrary() : defaultLibrary()
guard let function = useLib.makeFunction(name: funcName) else {
......@@ -65,7 +64,7 @@ extension MTLComputeCommandEncoder {
let width = computePipline.threadExecutionWidth
let height = computePipline.maxTotalThreadsPerThreadgroup/width
let threadsPerGroup = MTLSize.init(width: width, height: height, depth: 1)
// print(" thread: threads per group: \(threadsPerGroup) ")
// print(" thread: out texture width: \(outTexture.width) , out texture height: \(outTexture.height)")
......
......@@ -14,22 +14,71 @@
import Foundation
public protocol SummableMultipliable: Equatable {
static func +(lhs: Self, rhs: Self) -> Self
static func *(lhs: Self, rhs: Self) -> Self
static func -(lhs: Self, rhs: Self) -> Self
}
public protocol PrecisionType: SummableMultipliable{
init(inFloat: Float32)
init(inFloat16: Float16)
init<P: PrecisionType>(_ inP: P)
static var bitSize: UInt { get }
}
public typealias Float16 = Int16
extension Float16: PrecisionType {
public static func * (prefix: Float16, postfix: Float16) {
return prefix * postfix
}
public init<P>(_ inP: P) where P : PrecisionType {
if P.bitSize == Float32.bitSize {
self = Float16(inFloat: inP as! Float32)
} else if P.bitSize == Float16.bitSize {
self = inP as! Float16
} else {
fatalError()
}
}
public static var bitSize: UInt {
return 16
}
public init(inFloat16: Float16) {
self = inFloat16
}
public init(inFloat: Float32) {
self = Int16(inFloat)
}
}
public protocol PrecisionType {
init(inFloat: Float32)
}
extension Float32: PrecisionType {
public init<P>(_ inP: P) where P : PrecisionType {
if P.bitSize == Float32.bitSize {
self = inP as! Float32
} else if P.bitSize == Float16.bitSize {
self = Float32.init(inP as! Float16)
} else {
fatalError()
}
}
public init(inFloat: Float32) {
self = inFloat
}
public init(inFloat16: Float16) {
self = Float32.init(inFloat16)
}
public static var bitSize: UInt {
return 32
}
}
public enum DataLayout {
......
......@@ -55,7 +55,8 @@ public class Executor<P: PrecisionType> {
device = inDevice
queue = inQueue
for block in inProgram.programDesc.blocks {
for op in block.ops {
for i in 0..<7 {
let op = block.ops[i]
do {
let op = try OpCreator<P>.shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope)
op.inferShape()
......@@ -64,6 +65,15 @@ public class Executor<P: PrecisionType> {
throw error
}
}
// for op in block.ops {
// do {
// let op = try OpCreator<P>.shared.creat(device: inDevice, opDesc: op, scope: inProgram.scope)
// op.inferShape()
// ops.append(op)
// } catch let error {
// throw error
// }
// }
}
}
......
......@@ -104,12 +104,9 @@ public class Loader<P: PrecisionType> {
serializedData: modelData)
let originProgramDesc = ProgramDesc.init(protoProgram: protoProgram)
let programDesc = ProgramOptimize<P>.init().optimize(originProgramDesc: originProgramDesc)
print(programDesc)
fatalError()
guard let paraLoader = try? ParaLoader.init(paramPath: paraPath) else {
throw PaddleMobileError.loaderError(message: "load para error")
}
......@@ -180,7 +177,7 @@ public class Loader<P: PrecisionType> {
}
}
let program = Program.init(protoProgramDesc: protoProgram, inParamPath: paraPath, inScope: scope)
let program = Program.init(inProgramDesc: programDesc, inParamPath: paraPath, inScope: scope)
return program
} catch _ {
......
......@@ -40,13 +40,17 @@ class OpCreator<P: PrecisionType> {
}
let opCreators: [String : (MTLDevice, OpDesc, Scope) throws -> Runable & InferShaperable] =
[gConvType : ConvOp<P>.creat,
gBatchNormType : BatchNormOp<P>.creat,
gReluType : ReluOp<P>.creat,
gElementwiseAdd : ElementwiseAddOp<P>.creat,
gFeedType : FeedOp<P>.creat,
gFetchType : FetchOp<P>.creat,
gConvAddBatchNormReluType : ConvAddBatchNormReluOp<P>.creat]
[gConvType : ConvOp<P>.creat,
gBatchNormType : BatchNormOp<P>.creat,
gReluType : ReluOp<P>.creat,
gElementwiseAdd : ElementwiseAddOp<P>.creat,
gFeedType : FeedOp<P>.creat,
gFetchType : FetchOp<P>.creat,
gConvAddBatchNormReluType : ConvAddBatchNormReluOp<P>.creat,
gPooType : PoolOp<P>.creat,
gSoftmaxType : SoftmaxOp<P>.creat,
gReshapeType : ReshapeOp<P>.creat,
gConvAddType : ConvAddOp<P>.creat]
private init(){}
}
......@@ -18,9 +18,9 @@ import Foundation
protocol Fusion {
static func fusionNode() -> Node
static func change() -> [String : [(from: String, to: String)]]
static func fusionType() -> String
}
protocol Runable {
func run(device: MTLDevice, buffer: MTLCommandBuffer) throws
func runImpl(device: MTLDevice,buffer: MTLCommandBuffer) throws
......@@ -117,20 +117,20 @@ let gBatchNormType = "batch_norm"
let gReluType = "relu"
let gElementwiseAdd = "elementwise_add"
let gConvAddBatchNormReluType = "conv_add_batchnorm_relu"
let gPooType = "pool2d"
let gSoftmaxType = "softmax"
let gReshapeType = "reshape"
let gConvAddType = "conv_add"
let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
gBatchNormType : (inputs: ["X"], outputs: ["Y"]),
gReluType : (inputs: ["X"], outputs: ["Out"]),
gElementwiseAdd : (inputs: ["X", "Y"], outputs: ["Out"]),
gElementwiseAdd : (inputs: ["X"], outputs: ["Out"]),
gFeedType : (inputs: ["X"], outputs: ["Out"]),
gFetchType : (inputs: ["X"], outputs: ["Out"]),
gConvAddBatchNormReluType : (inputs: ["Input"], outputs: ["Out"])]
gConvAddBatchNormReluType : (inputs: ["Input"], outputs: ["Out"]),
gPooType : (inputs: ["X"], outputs: ["Out"]),
gSoftmaxType : (inputs: ["X"], outputs: ["Out"]),
gReshapeType : (inputs: ["X"], outputs: ["Out"]),
gConvAddType : (inputs: ["Input"], outputs: ["Out"])]
......@@ -14,9 +14,9 @@
import Foundation
struct BatchNormParam<P: PrecisionType>: OpParam {
class BatchNormParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
init(opDesc: OpDesc, inScope: Scope) throws {
required init(opDesc: OpDesc, inScope: Scope) throws {
do {
input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: inScope)
......
......@@ -8,21 +8,49 @@
import Foundation
class ConvAddBatchNormReluOp<P: PrecisionType>: Operator<ConvAddBatchNormReluKernel<P>, ConvParam<P>>, Runable, Creator, InferShaperable, Fusion{
static func fusionNode() -> Node {
let beginNode = Node.init(inType: gConvType)
_ = beginNode
--> Node.init(inType: gElementwiseAdd)
--> Node.init(inType: gBatchNormType)
--> Node.init(inType: gReluType)
return beginNode
class ConvAddBatchNormReluParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws {
do {
filter = try ConvAddBatchNormReluParam.inputFilter(paraInputs: opDesc.paraInputs, from: inScope)
input = try ConvAddBatchNormReluParam.input(inputs: opDesc.inputs, from: inScope)
output = try ConvAddBatchNormReluParam.output(outputs: opDesc.outputs, from: inScope)
stride = try ConvAddBatchNormReluParam.getAttr(key: "strides", attrs: opDesc.attrs)
paddings = try ConvAddBatchNormReluParam.getAttr(key: "paddings", attrs: opDesc.attrs)
dilations = try ConvAddBatchNormReluParam.getAttr(key: "dilations", attrs: opDesc.attrs)
epsilon = try ConvAddBatchNormReluParam.getAttr(key: "epsilon", attrs: opDesc.attrs)
groups = try ConvAddBatchNormReluParam.getAttr(key: "groups", attrs: opDesc.attrs)
variance = try ConvAddBatchNormReluParam.inputVariance(inputs: opDesc.paraInputs, from: inScope)
bias = try ConvAddBatchNormReluParam.inputBiase(inputs: opDesc.paraInputs, from: inScope)
scale = try ConvAddBatchNormReluParam.inputScale(inputs: opDesc.paraInputs, from: inScope)
mean = try ConvAddBatchNormReluParam.inputMean(inputs: opDesc.paraInputs, from: inScope)
y = try ConvAddBatchNormReluParam.inputY(inputs: opDesc.paraInputs, from: inScope)
} catch let error {
throw error
}
}
static func change() -> [String : [(from: String, to: String)]] {
return [:]
}
let input: Texture<P>
let variance: Tensor<ParamPrecisionType>
let bias: Tensor<ParamPrecisionType>
let mean: Tensor<ParamPrecisionType>
let scale: Tensor<ParamPrecisionType>
let y: Tensor<ParamPrecisionType>
let filter: Tensor<ParamPrecisionType>
let epsilon: Float32
var newScale: MTLBuffer?
var newBiase: MTLBuffer?
var output: Texture<P>
let stride: [Int32]
let paddings: [Int32]
let dilations: [Int32]
let groups: Int
}
class ConvAddBatchNormReluOp<P: PrecisionType>: Operator<ConvAddBatchNormReluKernel<P>, ConvAddBatchNormReluParam<P>>, Runable, Creator, InferShaperable, Fusion{
typealias OpType = ConvAddBatchNormReluOp<P>
func inferShape() {
......@@ -55,4 +83,20 @@ class ConvAddBatchNormReluOp<P: PrecisionType>: Operator<ConvAddBatchNormReluKer
}
}
static func fusionNode() -> Node {
let beginNode = Node.init(inType: gConvType)
_ = beginNode
--> Node.init(inType: gElementwiseAdd)
--> Node.init(inType: gBatchNormType)
--> Node.init(inType: gReluType)
return beginNode
}
static func change() -> [String : [(from: String, to: String)]] {
return [:]
}
static func fusionType() -> String {
return gConvAddBatchNormReluType
}
}
//
// ConvAddBatchNormReluOp.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/8.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
class ConvAddParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws {
do {
filter = try ConvAddParam.inputFilter(paraInputs: opDesc.paraInputs, from: inScope)
input = try ConvAddParam.input(inputs: opDesc.inputs, from: inScope)
output = try ConvAddParam.output(outputs: opDesc.outputs, from: inScope)
stride = try ConvAddParam.getAttr(key: "strides", attrs: opDesc.attrs)
paddings = try ConvAddParam.getAttr(key: "paddings", attrs: opDesc.attrs)
dilations = try ConvAddParam.getAttr(key: "dilations", attrs: opDesc.attrs)
groups = try ConvAddParam.getAttr(key: "groups", attrs: opDesc.attrs)
y = try ConvAddParam.inputY(inputs: opDesc.paraInputs, from: inScope)
} catch let error {
throw error
}
}
let input: Texture<P>
let y: Tensor<ParamPrecisionType>
let filter: Tensor<ParamPrecisionType>
var output: Texture<P>
let stride: [Int32]
let paddings: [Int32]
let dilations: [Int32]
let groups: Int
}
class ConvAddOp<P: PrecisionType>: Operator<ConvAddKernel<P>, ConvAddParam<P>>, Runable, Creator, InferShaperable, Fusion{
static func fusionNode() -> Node {
let beginNode = Node.init(inType: gConvType)
_ = beginNode
--> Node.init(inType: gElementwiseAdd)
return beginNode
}
static func change() -> [String : [(from: String, to: String)]] {
return [:]
}
static func fusionType() -> String {
return gConvAddType
}
typealias OpType = ConvAddOp<P>
func inferShape() {
let inDims = para.input.dim
let filterDim = para.filter.dim
let strides = para.stride
let paddings = para.paddings
let dilations = para.dilations
var outDim = [inDims[0]]
for i in 0..<strides.count {
let dilation: Int = Int(dilations[i])
let filterSize: Int = filterDim[i + 1]
let inputSize: Int = inDims[i + 1]
let padding: Int = Int(paddings[i])
let stride: Int = Int(strides[i])
let dKernel = dilation * (filterSize - 1) + 1
let outputSize = (inputSize + 2 * padding - dKernel) / stride + 1
outDim.append(outputSize)
}
outDim.append(filterDim[0])
para.output.dim = Dim.init(inDim: outDim)
}
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
}
}
......@@ -14,9 +14,9 @@
import Foundation
struct ConvParam<P: PrecisionType>: OpParam {
class ConvParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
init(opDesc: OpDesc, inScope: Scope) throws {
required init(opDesc: OpDesc, inScope: Scope) throws {
do {
filter = try ConvParam.inputFilter(paraInputs: opDesc.paraInputs, from: inScope)
input = try ConvParam.input(inputs: opDesc.inputs, from: inScope)
......@@ -25,14 +25,15 @@ struct ConvParam<P: PrecisionType>: OpParam {
paddings = try ConvParam.getAttr(key: "paddings", attrs: opDesc.attrs)
dilations = try ConvParam.getAttr(key: "dilations", attrs: opDesc.attrs)
groups = try ConvParam.getAttr(key: "groups", attrs: opDesc.attrs)
} catch let error {
throw error
}
}
let input: Texture<P>
var output: Texture<P>
let filter: Tensor<ParamPrecisionType>
var output: Texture<P>
let stride: [Int32]
let paddings: [Int32]
let dilations: [Int32]
......
......@@ -14,12 +14,13 @@
import Foundation
struct ElementwiseAddParam<P: PrecisionType>: OpParam {
class ElementwiseAddParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
init(opDesc: OpDesc, inScope: Scope) throws {
required init(opDesc: OpDesc, inScope: Scope) throws {
do {
input = try ElementwiseAddParam.inputX(inputs: opDesc.inputs, from: inScope)
inputY = try ElementwiseAddParam.inputY(inputs: opDesc.inputs, from: inScope)
inputY = try ElementwiseAddParam.inputY(inputs: opDesc.paraInputs, from: inScope)
output = try ElementwiseAddParam.outputOut(outputs: opDesc.outputs, from: inScope)
axis = try ElementwiseAddParam.getAttr(key: "axis", attrs: opDesc.attrs)
} catch let error {
......
......@@ -14,14 +14,14 @@
import Foundation
struct FeedParam<P: PrecisionType>: OpParam{
class FeedParam<P: PrecisionType>: OpParam{
var output: Texture<P>
var input: InputTexture {
return scope.input() as! InputTexture
}
let scope: Scope
init(opDesc: OpDesc, inScope: Scope) throws {
required init(opDesc: OpDesc, inScope: Scope) throws {
scope = inScope
do {
output = try FeedParam.outputOut(outputs: opDesc.outputs, from: inScope)
......
......@@ -14,11 +14,11 @@
import Foundation
struct FetchParam<P: PrecisionType>: OpParam{
class FetchParam<P: PrecisionType>: OpParam{
var output: ResultHolder<P> = ResultHolder.init(inDim: [], inResult: [])
let input: Texture<P>
let scope: Scope
init(opDesc: OpDesc, inScope: Scope) throws {
required init(opDesc: OpDesc, inScope: Scope) throws {
scope = inScope
do {
input = try FetchParam.inputX(inputs: opDesc.inputs, from: inScope)
......
......@@ -9,11 +9,59 @@
import Foundation
class ConvAddBatchNormReluKernel<P: PrecisionType>: Kernel, Computable {
required init(device: MTLDevice, param: ConvParam<P>) {
super.init(device: device, inFunctionName: "conv3x3")
var metalParam: MetalConvParam!
required init(device: MTLDevice, param: ConvAddBatchNormReluParam<P>) {
super.init(device: device, inFunctionName: "conv_add_batch_norm_relu_3x3")
let offsetX = param.filter.dim[2]/2 - Int(param.paddings[0])
let offsetY = param.filter.dim[1]/2 - Int(param.paddings[1])
let offsetZ = 0.0
metalParam = MetalConvParam.init(offsetX: Int16(offsetX), offsetY: Int16(offsetY), offsetZ: Int16(offsetZ), strideX: UInt16(param.stride[0]), strideY: UInt16(param.stride[1]), paddedZ: UInt16(param.input.metalTexture.arrayLength * 4 - param.input.dim[3]))
var invs: [P] = []
let varianceContents = param.variance.buffer.contents().assumingMemoryBound(to: P.self)
for i in 0..<param.variance.buffer.length/MemoryLayout<P>.stride {
let inv = pow(Float32.init(varianceContents[i]) + param.epsilon, 0.5)
invs.append(P(inv))
}
let newScale: UnsafeMutablePointer<P> = UnsafeMutablePointer<P>.allocate(capacity: param.scale.buffer.length)
let newBiase: UnsafeMutablePointer<P> = UnsafeMutablePointer<P>.allocate(capacity: param.bias.buffer.length)
let scaleContents = param.variance.buffer.contents().assumingMemoryBound(to: P.self)
let biaseContents = param.bias.buffer.contents().assumingMemoryBound(to: P.self)
let meanContents = param.mean.buffer.contents().assumingMemoryBound(to: P.self)
for i in 0..<param.scale.buffer.length/MemoryLayout<P>.stride {
newScale[i] = invs[i] * scaleContents[i]
newBiase[i] = biaseContents[i] - meanContents[i] * invs[i] * scaleContents[i]
}
param.newBiase = device.makeBuffer(bytes: newBiase, length: param.bias.buffer.length)
param.newScale = device.makeBuffer(bytes: newScale, length: param.scale.buffer.length)
newScale.deinitialize(count: param.scale.buffer.length)
newScale.deallocate()
newBiase.deinitialize(count: param.bias.buffer.length)
newBiase.deallocate()
}
func compute(commandBuffer: MTLCommandBuffer, param: ConvParam<P>) throws {
func compute(commandBuffer: MTLCommandBuffer, param: ConvAddBatchNormReluParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
print("ConvAddBatchNormReluKernel compute")
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setBytes(&metalParam, length: MemoryLayout<MetalConvParam>.size, index: 0)
encoder.setBuffer(param.filter.buffer, offset: 0, index: 1)
encoder.setBuffer(param.bias.buffer, offset: 0, index: 2)
encoder.setBuffer(param.newScale!, offset: 0, index: 3)
encoder.setBuffer(param.newBiase!, offset: 0, index: 4)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
}
//
// ConvKernel.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/5.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
class ConvAddKernel<P: PrecisionType>: Kernel, Computable {
required init(device: MTLDevice, param: ConvAddParam<P>) {
super.init(device: device, inFunctionName: "conv3x3")
}
func compute(commandBuffer: MTLCommandBuffer, param: ConvAddParam<P>) throws {
}
}
......@@ -52,3 +52,48 @@ kernel void conv3x3(texture2d_array<half, access::sample> inTexture [[texture(0)
}
outTexture.write(output, gid.xy, gid.z);
}
kernel void conv_add_batch_norm_relu_3x3(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half4 *weights [[buffer(1)]],
const device half4 *biase [[buffer(2)]],
const device half4 *new_scale [[buffer(3)]],
const device half4 *new_biase [[buffer(4)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
return;
}
short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY);
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
const uint wightSliceCount = 36;
uint weithTo = gid.z * wightSliceCount * inTexture.get_array_size();
half4 output = 0.0;
for (uint i = 0; i < inTexture.get_array_size(); ++i) {
half4 input[9];
input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i);
for (int j = 0; j < 9; ++j) {
half4 weight = weights[weithTo + wightSliceCount * i + j * 4];
output += dot(input[j], weight);
}
}
output = fmax((output + biase[gid.z]) * new_scale[gid.z] + new_biase[gid.z], 0.0h);
outTexture.write(output, gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class PoolKernel<P: PrecisionType>: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: PoolParam<P>) throws {
}
required init(device: MTLDevice, param: PoolParam<P>) {
super.init(device: device, inFunctionName: "relu")
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class ReshapeKernel<P: PrecisionType>: Kernel, Computable{
required init(device: MTLDevice, param: ReshapeParam<P>) {
super.init(device: device, inFunctionName: "relu")
}
func compute(commandBuffer: MTLCommandBuffer, param: ReshapeParam<P>) throws {
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class SoftmaxKernel<P: PrecisionType>: Kernel, Computable{
func compute(commandBuffer: MTLCommandBuffer, param: SoftmaxParam<P>) throws {
}
required init(device: MTLDevice, param: SoftmaxParam<P>) {
super.init(device: device, inFunctionName: "relu")
}
}
//
// PoolOp.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/9.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
class PoolParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws {
do {
input = try PoolParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try PoolParam.outputOut(outputs: opDesc.outputs, from: inScope)
} catch let error {
throw error
}
}
let input: Texture<P>
var output: Texture<P>
}
class PoolOp<P: PrecisionType>: Operator<PoolKernel<P>, PoolParam<P>>, Runable, Creator, InferShaperable{
func inferShape() {
para.output.dim = para.input.dim
}
typealias OpType = PoolOp<P>
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
}
}
......@@ -14,9 +14,9 @@
import Foundation
struct ReluParam<P: PrecisionType>: OpParam {
class ReluParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
init(opDesc: OpDesc, inScope: Scope) throws {
required init(opDesc: OpDesc, inScope: Scope) throws {
do {
input = try ReluParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try ReluParam.outputOut(outputs: opDesc.outputs, from: inScope)
......
//
// PoolOp.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/9.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
class ReshapeParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws {
do {
input = try ReshapeParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try ReshapeParam.outputOut(outputs: opDesc.outputs, from: inScope)
} catch let error {
throw error
}
}
let input: Texture<P>
var output: Texture<P>
}
class ReshapeOp<P: PrecisionType>: Operator<ReshapeKernel<P>, ReshapeParam<P>>, Runable, Creator, InferShaperable{
func inferShape() {
para.output.dim = para.input.dim
}
typealias OpType = ReshapeOp<P>
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
}
}
//
// PoolOp.swift
// paddle-mobile
//
// Created by liuRuiLong on 2018/7/9.
// Copyright © 2018年 orange. All rights reserved.
//
import Foundation
class SoftmaxParam<P: PrecisionType>: OpParam {
typealias ParamPrecisionType = P
required init(opDesc: OpDesc, inScope: Scope) throws {
do {
input = try SoftmaxParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try SoftmaxParam.outputOut(outputs: opDesc.outputs, from: inScope)
} catch let error {
throw error
}
}
let input: Texture<P>
var output: Texture<P>
}
class SoftmaxOp<P: PrecisionType>: Operator<SoftmaxKernel<P>, SoftmaxParam<P>>, Runable, Creator, InferShaperable{
func inferShape() {
para.output.dim = para.input.dim
}
typealias OpType = SoftmaxOp<P>
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
}
}
......@@ -20,7 +20,7 @@ struct OpDesc {
let outputs: [String : [String]]
let unusedOutputs: [String : [String]]
var attrs: [String : Attr] = [:]
let type: String
var type: String
init(protoOpDesc: PaddleMobile_Framework_Proto_OpDesc) {
type = protoOpDesc.type
let creator = { (vars: [PaddleMobile_Framework_Proto_OpDesc.Var], canAdd: (String) -> Bool) -> [String : [String]] in
......
......@@ -18,8 +18,8 @@ public struct Program {
let paramPath: String
let programDesc: ProgramDesc
let scope: Scope
init(protoProgramDesc: PaddleMobile_Framework_Proto_ProgramDesc, inParamPath: String, inScope: Scope) {
programDesc = ProgramDesc.init(protoProgram: protoProgramDesc)
init(inProgramDesc: ProgramDesc, inParamPath: String, inScope: Scope) {
programDesc = inProgramDesc
paramPath = inParamPath
scope = inScope
}
......
......@@ -18,7 +18,7 @@ infix operator --> : ChainNode
class Node {
var inputs: [Node] = []
var outputs: [Node] = []
let type: String
var type: String
var opDesc: OpDesc?
init(inOpDesc: OpDesc) {
type = inOpDesc.type
......@@ -36,11 +36,12 @@ class Node {
}
func depth(begin: UInt = 1) -> UInt {
var beginMax: UInt = 0
var beginMax: UInt = 1
for output in outputs {
let subDepth = output.depth(begin: begin + 1)
beginMax = max(begin, subDepth)
}
beginMax = max(begin, beginMax)
return beginMax
}
......@@ -50,23 +51,26 @@ class Node {
return beginNode
}
func folderWith(fusion: Fusion.Type) {
func folderWith(fusion: Fusion.Type, removedNodes: inout [Node]) {
let fusionNode = fusion.fusionNode()
let change = fusion.change()
let inOutputs = outputs
outputs.removeAll()
for i in 0..<inOutputs.count {
inOutputs[i].folderWith(beginNode: self, matchNode: fusionNode.outputs[i], change: change)
inOutputs[i].folderWith(beginNode: self, matchNode: fusionNode.outputs[i], change: change, removedNodes: &removedNodes)
}
opDesc?.type = fusion.fusionType()
type = fusion.fusionType()
}
private func folderWith(beginNode: Node, matchNode: Node, change: [String : [(from: String, to: String)]]) {
private func folderWith(beginNode: Node, matchNode: Node, change: [String : [(from: String, to: String)]], removedNodes: inout [Node]) {
guard let inOpdesc = opDesc else {
fatalError()
}
for attr in inOpdesc.attrs {
beginNode.opDesc?.attrs[attr.key] = attr.value
// print(beginNode.opDesc?.attrs)
}
for paraInput in inOpdesc.paraInputs {
......@@ -86,6 +90,11 @@ class Node {
if matchNode.outputs.count == 0 {
beginNode.outputs.append(contentsOf: outputs)
}
removedNodes.append(self)
for i in 0..<matchNode.outputs.count {
outputs[i].folderWith(beginNode: beginNode, matchNode: matchNode.outputs[i], change: change, removedNodes: &removedNodes)
}
}
......@@ -122,11 +131,10 @@ extension Node: Equatable {
return true
}
}
class ProgramOptimize<P: PrecisionType> {
let fusionOps: [Fusion.Type] = [ConvAddBatchNormReluOp<P>.self]
let fusionOps: [Fusion.Type] = [ConvAddBatchNormReluOp<P>.self, ConvAddOp<P>.self]
func optimize(originProgramDesc: ProgramDesc) -> ProgramDesc {
guard originProgramDesc.blocks.count == 1 else {
......@@ -141,7 +149,7 @@ class ProgramOptimize<P: PrecisionType> {
guard let opInputKeys = opInfos[opDesc.type]?.inputs, let outputKeys = opInfos[opDesc.type]?.outputs else {
fatalError()
}
let node = Node.init(inOpDesc: opDesc)
for inputKey in opInputKeys {
if let inputs = opDesc.inputs[inputKey] {
......@@ -164,28 +172,32 @@ class ProgramOptimize<P: PrecisionType> {
nodes.append(node)
if var nodes = typeMapNodes[opDesc.type] {
nodes.append(node)
typeMapNodes[opDesc.type] = nodes
if var inNodes = typeMapNodes[opDesc.type] {
inNodes.append(node)
typeMapNodes[opDesc.type] = inNodes
} else {
typeMapNodes[opDesc.type] = []
typeMapNodes[opDesc.type] = [node]
}
}
for fusion in fusionOps {
let fusionNode = fusion.fusionNode()
let depth = fusionNode.depth()
print(depth)
if let nodes = typeMapNodes[fusionNode.type] {
for node in nodes {
let toNode = node.to(depth: 4)
if let toMatchNodes = typeMapNodes[fusionNode.type] {
for node in toMatchNodes {
let toNode = node.to(depth: depth)
if toNode == fusionNode { // match
node.folderWith(fusion: fusion)
var removeNodes: [Node] = []
node.folderWith(fusion: fusion, removedNodes: &removeNodes)
for removeNode in removeNodes {
nodes.remove(element: removeNode)
}
}
}
}
}
var ops: [OpDesc] = []
for node in nodes {
ops.append(node.opDesc!)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册