未验证 提交 0a24e8c9 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #527 from codeWorm2015/metal

add program optimize
......@@ -14,8 +14,6 @@
FC039B8920E11C560081E9F8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8820E11C560081E9F8 /* Assets.xcassets */; };
FC039B8C20E11C560081E9F8 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = FC039B8A20E11C560081E9F8 /* LaunchScreen.storyboard */; };
FC0E2C1F20EDC030009C1FAC /* apple.jpg in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2C1E20EDC030009C1FAC /* apple.jpg */; };
FC0E2CED20EDC03B009C1FAC /* params in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2C2220EDC03B009C1FAC /* params */; };
FC0E2CEE20EDC03B009C1FAC /* model in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2C2320EDC03B009C1FAC /* model */; };
FC0E2CEF20EDC03B009C1FAC /* batch_norm_7.w_0 in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2C2520EDC03B009C1FAC /* batch_norm_7.w_0 */; };
FC0E2CF020EDC03B009C1FAC /* batch_norm_26.b_0 in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2C2620EDC03B009C1FAC /* batch_norm_26.b_0 */; };
FC0E2CF120EDC03B009C1FAC /* batch_norm_32.b_0 in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2C2720EDC03B009C1FAC /* batch_norm_32.b_0 */; };
......@@ -216,6 +214,10 @@
FC0E2DB420EDC03C009C1FAC /* conv2d_27.w_0 in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2CEA20EDC03B009C1FAC /* conv2d_27.w_0 */; };
FC0E2DB520EDC03C009C1FAC /* conv2d_33.w_0 in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2CEB20EDC03B009C1FAC /* conv2d_33.w_0 */; };
FC0E2DB620EDC03C009C1FAC /* depthwise_conv2d_7.w_0 in Resources */ = {isa = PBXBuildFile; fileRef = FC0E2CEC20EDC03B009C1FAC /* depthwise_conv2d_7.w_0 */; };
FCEBC0FC20F227C60099DBAF /* mobilenet in Resources */ = {isa = PBXBuildFile; fileRef = FCEBC0F820F227C60099DBAF /* mobilenet */; };
FCEBC0FD20F227C60099DBAF /* params in Resources */ = {isa = PBXBuildFile; fileRef = FCEBC0F920F227C60099DBAF /* params */; };
FCEBC0FE20F227C60099DBAF /* model in Resources */ = {isa = PBXBuildFile; fileRef = FCEBC0FA20F227C60099DBAF /* model */; };
FCEBC0FF20F227C60099DBAF /* yolo in Resources */ = {isa = PBXBuildFile; fileRef = FCEBC0FB20F227C60099DBAF /* yolo */; };
FCEBEC2C20E1391F00C0B14D /* paddle_mobile.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; };
FCEBEC2D20E1391F00C0B14D /* paddle_mobile.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
/* End PBXBuildFile section */
......@@ -246,8 +248,6 @@
FC039B8B20E11C560081E9F8 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/LaunchScreen.storyboard; sourceTree = "<group>"; };
FC039B8D20E11C560081E9F8 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
FC0E2C1E20EDC030009C1FAC /* apple.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = apple.jpg; sourceTree = "<group>"; };
FC0E2C2220EDC03B009C1FAC /* params */ = {isa = PBXFileReference; lastKnownFileType = file; path = params; sourceTree = "<group>"; };
FC0E2C2320EDC03B009C1FAC /* model */ = {isa = PBXFileReference; lastKnownFileType = file; path = model; sourceTree = "<group>"; };
FC0E2C2520EDC03B009C1FAC /* batch_norm_7.w_0 */ = {isa = PBXFileReference; lastKnownFileType = file; path = batch_norm_7.w_0; sourceTree = "<group>"; };
FC0E2C2620EDC03B009C1FAC /* batch_norm_26.b_0 */ = {isa = PBXFileReference; lastKnownFileType = file; path = batch_norm_26.b_0; sourceTree = "<group>"; };
FC0E2C2720EDC03B009C1FAC /* batch_norm_32.b_0 */ = {isa = PBXFileReference; lastKnownFileType = file; path = batch_norm_32.b_0; sourceTree = "<group>"; };
......@@ -448,6 +448,10 @@
FC0E2CEA20EDC03B009C1FAC /* conv2d_27.w_0 */ = {isa = PBXFileReference; lastKnownFileType = file; path = conv2d_27.w_0; sourceTree = "<group>"; };
FC0E2CEB20EDC03B009C1FAC /* conv2d_33.w_0 */ = {isa = PBXFileReference; lastKnownFileType = file; path = conv2d_33.w_0; sourceTree = "<group>"; };
FC0E2CEC20EDC03B009C1FAC /* depthwise_conv2d_7.w_0 */ = {isa = PBXFileReference; lastKnownFileType = file; path = depthwise_conv2d_7.w_0; sourceTree = "<group>"; };
FCEBC0F820F227C60099DBAF /* mobilenet */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = mobilenet; sourceTree = "<group>"; };
FCEBC0F920F227C60099DBAF /* params */ = {isa = PBXFileReference; lastKnownFileType = file; path = params; sourceTree = "<group>"; };
FCEBC0FA20F227C60099DBAF /* model */ = {isa = PBXFileReference; lastKnownFileType = file; path = model; sourceTree = "<group>"; };
FCEBC0FB20F227C60099DBAF /* yolo */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = yolo; sourceTree = "<group>"; };
FCEBEC2B20E1391F00C0B14D /* paddle_mobile.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; path = paddle_mobile.framework; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */
......@@ -527,22 +531,13 @@
FC0E2C2020EDC03B009C1FAC /* models */ = {
isa = PBXGroup;
children = (
FC0E2C2120EDC03B009C1FAC /* yolo */,
FCEBC0F720F227C60099DBAF /* yolo */,
FC0E2C2420EDC03B009C1FAC /* mobilenetssd */,
name = models;
path = ../../models;
sourceTree = "<group>";
FC0E2C2120EDC03B009C1FAC /* yolo */ = {
isa = PBXGroup;
children = (
FC0E2C2220EDC03B009C1FAC /* params */,
FC0E2C2320EDC03B009C1FAC /* model */,
path = yolo;
sourceTree = "<group>";
FC0E2C2420EDC03B009C1FAC /* mobilenetssd */ = {
isa = PBXGroup;
children = (
......@@ -750,6 +745,17 @@
path = mobilenetssd;
sourceTree = "<group>";
FCEBC0F720F227C60099DBAF /* yolo */ = {
isa = PBXGroup;
children = (
FCEBC0F820F227C60099DBAF /* mobilenet */,
FCEBC0F920F227C60099DBAF /* params */,
FCEBC0FA20F227C60099DBAF /* model */,
FCEBC0FB20F227C60099DBAF /* yolo */,
path = yolo;
sourceTree = "<group>";
/* End PBXGroup section */
/* Begin PBXNativeTarget section */
......@@ -811,7 +817,6 @@
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
FC0E2CEE20EDC03B009C1FAC /* model in Resources */,
FC0E2D8520EDC03C009C1FAC /* batch_norm_13.b_0 in Resources */,
FC0E2D2020EDC03B009C1FAC /* batch_norm_0.w_1 in Resources */,
FC0E2D6F20EDC03C009C1FAC /* batch_norm_12.w_1 in Resources */,
......@@ -821,9 +826,9 @@
FC0E2DA420EDC03C009C1FAC /* batch_norm_28.b_0 in Resources */,
FC0E2D9F20EDC03C009C1FAC /* batch_norm_33.w_2 in Resources */,
FC0E2D2920EDC03B009C1FAC /* batch_norm_2.b_0 in Resources */,
FC0E2CED20EDC03B009C1FAC /* params in Resources */,
FC0E2DA920EDC03C009C1FAC /* conv2d_26.w_0 in Resources */,
FC0E2D0420EDC03B009C1FAC /* batch_norm_16.w_2 in Resources */,
FCEBC0FE20F227C60099DBAF /* model in Resources */,
FC0E2D0720EDC03B009C1FAC /* batch_norm_6.w_1 in Resources */,
FC0E2DB020EDC03C009C1FAC /* batch_norm_30.w_2 in Resources */,
FC0E2D9720EDC03C009C1FAC /* conv2d_25.w_0 in Resources */,
......@@ -839,9 +844,11 @@
FC0E2DA620EDC03C009C1FAC /* depthwise_conv2d_4.w_0 in Resources */,
FC0E2D6920EDC03C009C1FAC /* conv2d_6.w_0 in Resources */,
FC0E2D6520EDC03C009C1FAC /* conv2d_7.w_0 in Resources */,
FCEBC0FD20F227C60099DBAF /* params in Resources */,
FC0E2DAB20EDC03C009C1FAC /* batch_norm_19.w_2 in Resources */,
FC0E2D9920EDC03C009C1FAC /* conv2d_31.w_0 in Resources */,
FC0E2D3020EDC03B009C1FAC /* batch_norm_34.w_0 in Resources */,
FCEBC0FC20F227C60099DBAF /* mobilenet in Resources */,
FC0E2D1220EDC03B009C1FAC /* batch_norm_34.b_0 in Resources */,
FC0E2D4D20EDC03C009C1FAC /* batch_norm_7.b_0 in Resources */,
FC0E2D2520EDC03B009C1FAC /* batch_norm_21.w_1 in Resources */,
......@@ -942,6 +949,7 @@
FC0E2D0F20EDC03B009C1FAC /* batch_norm_5.w_0 in Resources */,
FC0E2D4520EDC03C009C1FAC /* batch_norm_9.w_2 in Resources */,
FC0E2D9020EDC03C009C1FAC /* batch_norm_23.w_2 in Resources */,
FCEBC0FF20F227C60099DBAF /* yolo in Resources */,
FC0E2D6720EDC03C009C1FAC /* conv2d_31.b_0 in Resources */,
FC0E2DA020EDC03C009C1FAC /* conv2d_18.w_0 in Resources */,
FC0E2D1C20EDC03B009C1FAC /* conv2d_13.w_0 in Resources */,
......@@ -36,6 +36,8 @@
FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */; };
FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC1B16B220EC9A4F00678B91 /* Kernels.metal */; };
FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC1B186520ECF1C600678B91 /* ResizeKernel.swift */; };
FC4CB74920F0B954007C0C6D /* ConvKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FC4CB74820F0B954007C0C6D /* ConvKernel.metal */; };
FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */; };
FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */; };
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC60DB8820E9AAA500FF203F /* MetalExtension.swift */; };
FC82735920E3C04200BE430A /* OpCreator.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC82735820E3C04200BE430A /* OpCreator.swift */; };
......@@ -43,6 +45,8 @@
FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D037F20E22FBB000F735A /* FeedOp.swift */; };
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038120E2312E000F735A /* FetchOp.swift */; };
FC9D038420E23B01000F735A /* Texture.swift in Sources */ = {isa = PBXBuildFile; fileRef = FC9D038320E23B01000F735A /* Texture.swift */; };
FCEBC0F420F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCEBC0F320F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift */; };
FCEBC0F620F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCEBC0F520F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift */; };
FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = FCF2D73720E64E70007AC5F5 /* Kernel.swift */; };
/* End PBXBuildFile section */
......@@ -80,6 +84,8 @@
FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ElementwiseAddKernel.swift; sourceTree = "<group>"; };
FC1B16B220EC9A4F00678B91 /* Kernels.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = Kernels.metal; sourceTree = "<group>"; };
FC1B186520ECF1C600678B91 /* ResizeKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ResizeKernel.swift; sourceTree = "<group>"; };
FC4CB74820F0B954007C0C6D /* ConvKernel.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvKernel.metal; sourceTree = "<group>"; };
FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ProgramOptimize.swift; sourceTree = "<group>"; };
FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture2DTo2DArrayKernel.swift; sourceTree = "<group>"; };
FC60DB8820E9AAA500FF203F /* MetalExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MetalExtension.swift; sourceTree = "<group>"; };
FC82735820E3C04200BE430A /* OpCreator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpCreator.swift; sourceTree = "<group>"; };
......@@ -87,6 +93,8 @@
FC9D037F20E22FBB000F735A /* FeedOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FeedOp.swift; sourceTree = "<group>"; };
FC9D038120E2312E000F735A /* FetchOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FetchOp.swift; sourceTree = "<group>"; };
FC9D038320E23B01000F735A /* Texture.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Texture.swift; sourceTree = "<group>"; };
FCEBC0F320F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; name = ConvAddBatchNormReluOp.swift; path = "paddle-mobile/Operators/ConvAddBatchNormReluOp.swift"; sourceTree = SOURCE_ROOT; };
FCEBC0F520F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddBatchNormReluKernel.swift; sourceTree = "<group>"; };
FCF2D73720E64E70007AC5F5 /* Kernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; name = Kernel.swift; path = "paddle-mobile/Operators/Kernels/Kernel.swift"; sourceTree = SOURCE_ROOT; };
/* End PBXFileReference section */
......@@ -178,6 +186,7 @@
children = (
FC086BA520E67E8500D85EF7 /* Kernels */,
FCD592FA20E248EC00252966 /* Base */,
FCEBC0F320F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift */,
FC039BA420E11CBC0081E9F8 /* ConvOp.swift */,
FC039BA520E11CBC0081E9F8 /* ElementwiseAddOp.swift */,
FC039BA720E11CBC0081E9F8 /* BatchNormOp.swift */,
......@@ -200,6 +209,7 @@
FC039BB520E11CC20081E9F8 /* OpDesc.swift */,
FC039BB620E11CC20081E9F8 /* Attribute.swift */,
FC039BB720E11CC20081E9F8 /* BlockDesc.swift */,
FC4CB74A20F12C30007C0C6D /* ProgramOptimize.swift */,
path = Program;
sourceTree = "<group>";
......@@ -215,6 +225,8 @@
FC0E2DBD20EE460D009C1FAC /* BatchNormKernel.swift */,
FC0E2DBF20EE461F009C1FAC /* ElementwiseAddKernel.swift */,
FC5163F520EF556E00636C28 /* Texture2DTo2DArrayKernel.swift */,
FC4CB74820F0B954007C0C6D /* ConvKernel.metal */,
FCEBC0F520F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift */,
path = Kernels;
sourceTree = "<group>";
......@@ -338,13 +350,16 @@
FC9D037920E229E4000F735A /* OpParam.swift in Sources */,
FC1B186620ECF1C600678B91 /* ResizeKernel.swift in Sources */,
FCF2D73820E64E70007AC5F5 /* Kernel.swift in Sources */,
FCEBC0F420F1FDD90099DBAF /* ConvAddBatchNormReluOp.swift in Sources */,
FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */,
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */,
FCEBC0F620F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift in Sources */,
FC1B16B320EC9A4F00678B91 /* Kernels.metal in Sources */,
FC039BBA20E11CC20081E9F8 /* TensorDesc.swift in Sources */,
FC039BA020E11CB20081E9F8 /* Dim.swift in Sources */,
FC039BB820E11CC20081E9F8 /* framework.pb.swift in Sources */,
FC039B9920E11C9A0081E9F8 /* Types.swift in Sources */,
FC4CB74920F0B954007C0C6D /* ConvKernel.metal in Sources */,
FC039BA920E11CBC0081E9F8 /* ConvOp.swift in Sources */,
FC9D038420E23B01000F735A /* Texture.swift in Sources */,
FC039B9820E11C9A0081E9F8 /* Errors.swift in Sources */,
......@@ -359,6 +374,7 @@
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */,
FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */,
FC039BA220E11CB70081E9F8 /* Loader.swift in Sources */,
FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */,
FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */,
FC039BC020E11CC20081E9F8 /* BlockDesc.swift in Sources */,
FC039BAD20E11CBC0081E9F8 /* ReluOp.swift in Sources */,
......@@ -102,8 +102,14 @@ public class Loader<P: PrecisionType> {
do {
let protoProgram = try PaddleMobile_Framework_Proto_ProgramDesc.init(
serializedData: modelData)
let programDesc = ProgramDesc.init(protoProgram: protoProgram)
let originProgramDesc = ProgramDesc.init(protoProgram: protoProgram)
let programDesc = ProgramOptimize<P>.init().optimize(originProgramDesc: originProgramDesc)
guard let paraLoader = try? ParaLoader.init(paramPath: paraPath) else {
throw PaddleMobileError.loaderError(message: "load para error")
......@@ -159,6 +165,7 @@ public class Loader<P: PrecisionType> {
throw error
tensor.convert(to: .NHWC)
tensor.initBuffer(device: device)
scope[varDesc.name] = tensor
} else {
let dim = Dim.init(inDim: tensorDesc.NHWCDim)
......@@ -45,7 +45,8 @@ class OpCreator<P: PrecisionType> {
gReluType : ReluOp<P>.creat,
gElementwiseAdd : ElementwiseAddOp<P>.creat,
gFeedType : FeedOp<P>.creat,
gFetchType : FetchOp<P>.creat]
gFetchType : FetchOp<P>.creat,
gConvAddBatchNormReluType : ConvAddBatchNormReluOp<P>.creat]
private init(){}
......@@ -15,6 +15,12 @@
import Metal
import Foundation
/// Adopted by operators that can replace a chain of simpler operators.
protocol Fusion {
/// Returns the first node of the op-type chain that the adopting operator matches.
static func fusionNode() -> Node
/// Returns, per string key, a list of (from, to) rename pairs.
/// NOTE(review): what the outer key identifies is not visible in this excerpt — confirm against the caller.
static func change() -> [String : [(from: String, to: String)]]
protocol Runable {
func run(device: MTLDevice, buffer: MTLCommandBuffer) throws
func runImpl(device: MTLDevice,buffer: MTLCommandBuffer) throws
......@@ -56,11 +62,11 @@ protocol InferShaperable {
protocol OperatorProtocol {
associatedtype ParamType: OpParam
associatedtype KerType: Computable
associatedtype ParamType
associatedtype KerType: Computable where Self.KerType.ParamType == ParamType
var type: String { get }
var inputs: [String : [String]] { get }
var paraInputs: [String : [String]] { get }
var paraInputs: [String : [String]] { get set }
var outpus: [String : [String]] { get }
var attrs: [String : Attr] { get }
var para: ParamType { get }
......@@ -78,13 +84,12 @@ extension OperatorProtocol {
class Operator <ParameterType: OpParam, KernelType: Computable>: OperatorProtocol{
class Operator <KernelType: Computable , ParameterType>: OperatorProtocol where KernelType.ParamType == ParameterType {
typealias ParamType = ParameterType
typealias KerType = KernelType
let type: String
let inputs: [String : [String]]
let paraInputs: [String : [String]]
var paraInputs: [String : [String]]
let outpus: [String : [String]]
let attrs: [String : Attr]
let para: ParamType
......@@ -95,27 +100,37 @@ class Operator <ParameterType: OpParam, KernelType: Computable>: OperatorProtoc
outpus = opDesc.outputs
attrs = opDesc.attrs
paraInputs = opDesc.paraInputs
kernel = KerType.init(device: device)
do {
para = try ParamType.init(opDesc:opDesc, inScope: inScope)
} catch let error {
throw error
kernel = KernelType.init(device: device, param: para)
// op infos
let gFetchType = "fetch"
let gFeedType = "feed"
let gConvType = "conv2d"
let gBatchNormType = "batch_norm"
let gReluType = "relu"
let gElementwiseAdd = "elementwise_add"
let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
gBatchNormType : (inputs: ["X"], outputs: ["Y"]),
gReluType : (inputs: ["X"], outputs: ["Out"]),
gElementwiseAdd : (inputs: ["X", "Y"], outputs: ["Out"]),
gFeedType : (inputs: ["X"], outputs: ["Out"]),
gFetchType : (inputs: ["X"], outputs: ["Out"])]
let gFetchType = "fetch"
let gFeedType = "feed"
let gConvType = "conv2d"
let gBatchNormType = "batch_norm"
let gReluType = "relu"
let gElementwiseAdd = "elementwise_add"
let gConvAddBatchNormReluType = "conv_add_batchnorm_relu"
let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
gBatchNormType : (inputs: ["X"], outputs: ["Y"]),
gReluType : (inputs: ["X"], outputs: ["Out"]),
gElementwiseAdd : (inputs: ["X", "Y"], outputs: ["Out"]),
gFeedType : (inputs: ["X"], outputs: ["Out"]),
gFetchType : (inputs: ["X"], outputs: ["Out"]),
gConvAddBatchNormReluType : (inputs: ["Input"], outputs: ["Out"])]
......@@ -42,7 +42,7 @@ struct BatchNormParam<P: PrecisionType>: OpParam {
let is_test: Bool
class BatchNormOp<P: PrecisionType>: Operator<BatchNormParam<P>, BatchNormKernel<P>>, Runable, Creator, InferShaperable{
class BatchNormOp<P: PrecisionType>: Operator<BatchNormKernel<P>, BatchNormParam<P>>, Runable, Creator, InferShaperable{
func inferShape() {
para.output.dim = para.input.dim
// ConvAddBatchNormReluOp.swift
// paddle-mobile
// Created by liuRuiLong on 2018/7/8.
// Copyright © 2018年 orange. All rights reserved.
import Foundation
/// Operator whose fusion pattern is the chain of op types built in `fusionNode()`.
/// It specializes `Operator` with `ConvAddBatchNormReluKernel<P>` as the kernel and `ConvParam<P>` as the parameter type.
class ConvAddBatchNormReluOp<P: PrecisionType>: Operator<ConvAddBatchNormReluKernel<P>, ConvParam<P>>, Runable, Creator, InferShaperable, Fusion{
/// Builds the pattern gConvType --> gElementwiseAdd --> gBatchNormType --> gReluType and returns its first node.
static func fusionNode() -> Node {
let beginNode = Node.init(inType: gConvType)
// The `-->` operator is defined elsewhere; its result is discarded here and only `beginNode` is returned.
_ = beginNode
--> Node.init(inType: gElementwiseAdd)
--> Node.init(inType: gBatchNormType)
--> Node.init(inType: gReluType)
return beginNode
/// Rename table required by `Fusion`; this operator returns an empty dictionary.
static func change() -> [String : [(from: String, to: String)]] {
return [:]
typealias OpType = ConvAddBatchNormReluOp<P>
/// Derives `para.output.dim` from the input dims, filter dims, strides, paddings and dilations.
func inferShape() {
let inDims = para.input.dim
let filterDim = para.filter.dim
let strides = para.stride
let paddings = para.paddings
let dilations = para.dilations
// The output dims start with the first input dim.
var outDim = [inDims[0]]
for i in 0..<strides.count {
let dilation: Int = Int(dilations[i])
let filterSize: Int = filterDim[i + 1]
let inputSize: Int = inDims[i + 1]
let padding: Int = Int(paddings[i])
let stride: Int = Int(strides[i])
// Extent covered by the filter along this axis once dilation is applied.
let dKernel = dilation * (filterSize - 1) + 1
// Integer division: any remainder is truncated.
let outputSize = (inputSize + 2 * padding - dKernel) / stride + 1
// NOTE(review): no statement appending `outputSize` (or an output-channel dim) to `outDim` is visible in this excerpt — confirm against the full file.
para.output.dim = Dim.init(inDim: outDim)
/// Encodes this operator's kernel into `buffer` using the stored parameters.
/// - Throws: whatever `kernel.compute` throws.
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
// NOTE(review): the do/catch only rethrows the caught error unchanged.
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
......@@ -39,7 +39,15 @@ struct ConvParam<P: PrecisionType>: OpParam {
let groups: Int
class ConvOp<P: PrecisionType>: Operator<ConvParam<P>, ConvKernel<P>>, Runable, Creator, InferShaperable {
class ConvOp<P: PrecisionType>: Operator<ConvKernel<P>, ConvParam<P>>, Runable, Creator, InferShaperable {
required init(device: MTLDevice, opDesc: OpDesc, inScope: Scope) throws {
do {
try super.init(device: device, opDesc: opDesc, inScope: inScope)
} catch let error {
throw error
func inferShape() {
let inDims = para.input.dim
let filterDim = para.filter.dim
......@@ -69,7 +77,6 @@ class ConvOp<P: PrecisionType>: Operator<ConvParam<P>, ConvKernel<P>>, Runable,
} catch let error {
throw error
func delogOutput() {
......@@ -32,7 +32,7 @@ struct ElementwiseAddParam<P: PrecisionType>: OpParam {
let axis: Int
class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddParam<P>, ElementwiseAddKernel<P>>, Runable, Creator, InferShaperable{
class ElementwiseAddOp<P: PrecisionType>: Operator<ElementwiseAddKernel<P>, ElementwiseAddParam<P>>, Runable, Creator, InferShaperable{
func inferShape() {
para.output.dim = para.input.dim
......@@ -33,7 +33,7 @@ struct FeedParam<P: PrecisionType>: OpParam{
typealias ParamPrecisionType = P
class FeedOp<P: PrecisionType>: Operator<FeedParam<P>, Texture2DTo2DArrayKernel<P>>, Runable, Creator, InferShaperable {
class FeedOp<P: PrecisionType>: Operator<Texture2DTo2DArrayKernel<P>, FeedParam<P>>, Runable, Creator, InferShaperable {
typealias OpType = FeedOp<P>
func inferShape() {
......@@ -44,9 +44,8 @@ class FeedOp<P: PrecisionType>: Operator<FeedParam<P>, Texture2DTo2DArrayKernel<
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
let locPara = Texture2DTo2DArrayParam.init(input: para.input.mtlTexture, output: para.output.metalTexture, expectDim: para.input.expectDim)
do {
try kernel.compute(commandBuffer: buffer, param: locPara)
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
......@@ -30,7 +30,17 @@ struct FetchParam<P: PrecisionType>: OpParam{
typealias ParamPrecisionType = P
class FetchOp<P: PrecisionType>: Operator<FetchParam<P>, ResizeKernel<P>>, Runable, Creator, InferShaperable{
/// Kernel type used by `FetchOp`; it is parameterised by `FetchParam<P>`.
class FetchKernel<P: PrecisionType>: Kernel, Computable {
/// NOTE(review): no statements of this method body are visible in this excerpt.
func compute(commandBuffer: MTLCommandBuffer, param: FetchParam<P>) throws {
/// Initializes the base `Kernel` with the function name "texture2d_to_2d_array"; `param` is not read in the visible lines.
required init(device: MTLDevice, param: FetchParam<P>) {
super.init(device: device, inFunctionName: "texture2d_to_2d_array")
class FetchOp<P: PrecisionType>: Operator< FetchKernel<P>, FetchParam<P>>, Runable, Creator, InferShaperable{
func inferShape() {
......@@ -9,7 +9,7 @@
import Foundation
class BatchNormKernel<P: PrecisionType>: Kernel, Computable {
required init(device: MTLDevice) {
required init(device: MTLDevice, param: BatchNormParam<P>) {
super.init(device: device, inFunctionName: "batchnorm")
// ConvAddBatchNormReluKernel.swift
// paddle-mobile
// Created by liuRuiLong on 2018/7/5.
// Copyright © 2018年 orange. All rights reserved.
import Foundation
/// Kernel type used by `ConvAddBatchNormReluOp`; it is parameterised by `ConvParam<P>`.
class ConvAddBatchNormReluKernel<P: PrecisionType>: Kernel, Computable {
/// Initializes the base `Kernel` with the function name "conv3x3"; `param` is not read in the visible lines.
required init(device: MTLDevice, param: ConvParam<P>) {
super.init(device: device, inFunctionName: "conv3x3")
/// NOTE(review): no statements of this method body are visible in this excerpt.
func compute(commandBuffer: MTLCommandBuffer, param: ConvParam<P>) throws {
// ConvKernel.metal
// paddle-mobile
// Created by liuRuiLong on 2018/7/7.
// Copyright © 2018年 orange. All rights reserved.
#include <metal_stdlib>
using namespace metal;
// Parameter block read by conv3x3 from buffer(0).
// NOTE(review): the Swift struct of the same name also declares a trailing `paddedZ` field that is not listed here — confirm the two layouts are meant to differ.
struct MetalConvParam {
short offsetX;
short offsetY;
short offsetZ;
ushort strideX;
ushort strideY;
// Convolution over a 3x3 neighbourhood of every input slice.
// Each thread writes one texel of one slice of outTexture, indexed by gid.
kernel void conv3x3(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant MetalConvParam &param [[buffer(0)]],
const device half4 *weights [[buffer(1)]],
uint3 gid [[thread_position_in_grid]]) {
// Guard for threads that fall outside the output texture.
// NOTE(review): the body of this guard is not visible in this excerpt.
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) {
// Centre of the 3x3 window in the input: the output coordinate shifted by the offsets.
// NOTE(review): param.strideX / param.strideY are not read in the visible lines.
short2 posInInput = short2(gid.xy) + short2(param.offsetX, param.offsetY);
// Pixel coordinates, nearest filtering, clamp_to_zero addressing for out-of-range coordinates.
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
// Number of half4 weight entries per (output slice, input slice) pair.
// Presumably 9 taps x 4 output channels — TODO confirm.
const uint wightSliceCount = 36;
// Index of the first weight entry for this output slice.
uint weithTo = gid.z * wightSliceCount * inTexture.get_array_size();
half4 output = 0.0;
for (uint i = 0; i < inTexture.get_array_size(); ++i) {
// Gather the 3x3 neighbourhood of input slice i, row by row.
half4 input[9];
input[0] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y - 1), i);
input[1] = inTexture.sample(sample, float2(posInInput.x, posInInput.y - 1), i);
input[2] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y - 1), i);
input[3] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y), i);
input[4] = inTexture.sample(sample, float2(posInInput.x, posInInput.y), i);
input[5] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y), i);
input[6] = inTexture.sample(sample, float2(posInInput.x - 1, posInInput.y + 1), i);
input[7] = inTexture.sample(sample, float2(posInInput.x, posInInput.y + 1), i);
input[8] = inTexture.sample(sample, float2(posInInput.x + 1, posInInput.y + 1), i);
for (int j = 0; j < 9; ++j) {
// Reads one weight entry per tap, at stride 4 within the slice.
half4 weight = weights[weithTo + wightSliceCount * i + j * 4];
// NOTE(review): dot() yields a scalar, so the same value is added to all four lanes of `output` — confirm this is intended.
output += dot(input[j], weight);
outTexture.write(output, gid.xy, gid.z);
......@@ -9,20 +9,37 @@
import Foundation
/// Parameter block that `ConvKernel.compute` copies to buffer index 0 with `setBytes`.
/// NOTE(review): the Metal struct of the same name shown in ConvKernel.metal has no `paddedZ` member — confirm the layouts are meant to differ.
struct MetalConvParam {
let offsetX: Int16
let offsetY: Int16
let offsetZ: Int16
let strideX: UInt16
let strideY: UInt16
let paddedZ: UInt16
/// Kernel initialized with the function name "conv3x3"; it builds its `MetalConvParam` once at construction.
class ConvKernel<P: PrecisionType>: Kernel, Computable {
// Implicitly unwrapped: assigned in `init` after `super.init`.
var metalParam: MetalConvParam!
/// Initializes the base `Kernel` with "conv3x3" and derives the parameter block from `param`.
required init(device: MTLDevice, param: ConvParam<P>) {
super.init(device: device, inFunctionName: "conv3x3")
// Offset = half the filter extent minus the padding.
// Assumes filter.dim[2] is width and filter.dim[1] is height — TODO confirm.
let offsetX = param.filter.dim[2]/2 - Int(param.paddings[0])
let offsetY = param.filter.dim[1]/2 - Int(param.paddings[1])
let offsetZ = 0.0
// paddedZ = arrayLength * 4 - input.dim[3].
// Presumably the count of unused channel lanes in the last texture slice — TODO confirm.
metalParam = MetalConvParam.init(offsetX: Int16(offsetX), offsetY: Int16(offsetY), offsetZ: Int16(offsetZ), strideX: UInt16(param.stride[0]), strideY: UInt16(param.stride[1]), paddedZ: UInt16(param.input.metalTexture.arrayLength * 4 - param.input.dim[3]))
/// Encodes the convolution: input texture at index 0, output texture at index 1, parameter block at buffer 0, filter weights at buffer 1.
/// - Throws: `PaddleMobileError.predictError` when no compute encoder can be created.
func compute(commandBuffer: MTLCommandBuffer, param: ConvParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setBytes(&metalParam, length: MemoryLayout<MetalConvParam>.size, index: 0)
encoder.setBuffer(param.filter.buffer, offset: 0, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
// NOTE(review): no `encoder.endEncoding()` call is visible in this excerpt — confirm against the full file.
required init(device: MTLDevice) {
super.init(device: device, inFunctionName: "conv")
......@@ -10,8 +10,8 @@ import Foundation
class ElementwiseAddKernel<P: PrecisionType>: Kernel, Computable {
required init(device: MTLDevice) {
super.init(device: device, inFunctionName: "conv")
required init(device: MTLDevice, param: ElementwiseAddParam<P>) {
super.init(device: device, inFunctionName: "elementwise_add")
func compute(commandBuffer: MTLCommandBuffer, param: ElementwiseAddParam<P>) throws {
......@@ -16,14 +16,15 @@ import Metal
import Foundation
protocol Computable {
associatedtype ParamType
associatedtype ParamType: OpParam
func compute(commandBuffer: MTLCommandBuffer, param: ParamType) throws
init(device: MTLDevice)
init(device: MTLDevice, param: ParamType)
protocol KernelProtocol {
var pipline: MTLComputePipelineState { get set }
var functionName: String { get set }
class Kernel {
......@@ -60,16 +60,7 @@ kernel void elementwise_add(texture2d_array<half, access::read> inTexture [[text
outTexture.write(input, gid.xy, gid.z);
kernel void conv(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
outTexture.write(input, gid.xy, gid.z);
kernel void batchnorm(texture2d_array<half, access::read> inTexture [[texture(0)]],
......@@ -25,7 +25,7 @@ class ReluKernel<P: PrecisionType>: Kernel, Computable{
required init(device: MTLDevice) {
required init(device: MTLDevice, param: ReluParam<P>) {
super.init(device: device, inFunctionName: "relu")
......@@ -11,48 +11,52 @@
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
import MetalPerformanceShaders
struct ResizeParam {
let input: MTLTexture
let output: MTLTexture
let expectDim: Dim
struct OutputDim {
let width: UInt16
let height: UInt16
let strideX: UInt16
let strideY: UInt16
class ResizeKernel<P: PrecisionType>: Kernel, Computable{
var lanczos: MPSImageLanczosScale
required init(device: MTLDevice) {
lanczos = MPSImageLanczosScale.init(device: device)
super.init(device: device, inFunctionName: "resize")
func compute(commandBuffer: MTLCommandBuffer, param: ResizeParam) throws {
// guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
// throw PaddleMobileError.predictError(message: " encode is nil")
// }
lanczos.encode(commandBuffer: commandBuffer, sourceTexture: param.input, destinationTexture: param.output)
// encoder.setTexture(param.input, index: 0)
// encoder.setTexture(param.output, index: 1)
// let strideX = param.input.width/param.expectDim[2]
// let strideY = param.input.height/param.expectDim[1]
// var outputDim = OutputDim.init(width: UInt16(param.expectDim[1]), height: UInt16(param.expectDim[2]), strideX: UInt16(strideX), strideY: UInt16(strideY))
// encoder.setBytes(&outputDim, length: MemoryLayout<OutputDim>.size, index: 0)
// encoder.dispatch(computePipline: pipline, outTexture: param.output)
// encoder.endEncoding()
//import Foundation
//import MetalPerformanceShaders
//struct ResizeParam: OpParam{
// typealias OutputType = <#type#>
// typealias ParamPrecisionType = <#type#>
// let input: MTLTexture
// let output: MTLTexture
// let expectDim: Dim
//struct OutputDim {
// let width: UInt16
// let height: UInt16
// let strideX: UInt16
// let strideY: UInt16
//class ResizeKernel<P: PrecisionType>: Kernel, Computable{
// var lanczos: MPSImageLanczosScale
// required init(device: MTLDevice, param: ResizeParam) {
// lanczos = MPSImageLanczosScale.init(device: device)
// super.init(device: device, inFunctionName: "resize")
// }
// func compute(commandBuffer: MTLCommandBuffer, param: ResizeParam) throws {
//// guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
//// throw PaddleMobileError.predictError(message: " encode is nil")
//// }
// lanczos.encode(commandBuffer: commandBuffer, sourceTexture: param.input, destinationTexture: param.output)
//// encoder.setTexture(param.input, index: 0)
//// encoder.setTexture(param.output, index: 1)
//// let strideX = param.input.width/param.expectDim[2]
//// let strideY = param.input.height/param.expectDim[1]
//// var outputDim = OutputDim.init(width: UInt16(param.expectDim[1]), height: UInt16(param.expectDim[2]), strideX: UInt16(strideX), strideY: UInt16(strideY))
//// encoder.setBytes(&outputDim, length: MemoryLayout<OutputDim>.size, index: 0)
//// encoder.dispatch(computePipline: pipline, outTexture: param.output)
//// encoder.endEncoding()
// }
......@@ -16,17 +16,17 @@ struct Texture2DTo2DArrayParam {
/// Kernel bound to the Metal function named "texture2d_to_2d_array" (the name
/// handed to `super.init` below).
// NOTE(review): this span reads like a rendered diff with indentation and
// brace-only lines missing. The two consecutive `compute` signatures, the two
// groups of `encoder.*` calls and the two `init` signatures look like the
// before/after versions of the same declarations — confirm against the file.
class Texture2DTo2DArrayKernel<P: PrecisionType>: Kernel, Computable{
/// Encodes one compute dispatch on `commandBuffer`: the input texture is bound
/// at index 0, the output texture at index 1, and `pipline` is dispatched with
/// the input texture passed as `outTexture`.
/// - Throws: `PaddleMobileError.predictError` when the command buffer cannot
///   provide a compute command encoder.
func compute(commandBuffer: MTLCommandBuffer, param: Texture2DTo2DArrayParam) throws {
func compute(commandBuffer: MTLCommandBuffer, param: FeedParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
// Variant that passes `param.input` / `param.output` to the encoder directly.
encoder.setTexture(param.input, index: 0)
encoder.setTexture(param.output, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.input)
// Variant that passes `param.input.mtlTexture` / `param.output.metalTexture`.
// NOTE(review): no `encoder.endEncoding()` call is visible in this view — confirm
// it exists in the real file.
encoder.setTexture(param.input.mtlTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.dispatch(computePipline: pipline, outTexture: param.input.mtlTexture)
/// Creates the kernel for `device`. The `param` argument of the second
/// signature is not used by the visible body.
required init(device: MTLDevice) {
required init(device: MTLDevice, param: FeedParam<P>) {
super.init(device: device, inFunctionName: "texture2d_to_2d_array")
......@@ -28,7 +28,7 @@ struct ReluParam<P: PrecisionType>: OpParam {
var output: Texture<P>
class ReluOp<P: PrecisionType>: Operator<ReluParam<P>, ReluKernel<P>>, Runable, Creator, InferShaperable{
class ReluOp<P: PrecisionType>: Operator<ReluKernel<P>, ReluParam<P>>, Runable, Creator, InferShaperable{
/// Shape inference for this op: the output's `dim` is set to the input's `dim`.
func inferShape() {
para.output.dim = para.input.dim
......@@ -35,4 +35,33 @@ struct BlockDesc {
self.ops = ops
/// Builds a block description from variable and operator descriptions that
/// are already available in memory.
///
/// - Parameters:
///   - inVars: stored as `vars`.
///   - inOps: stored as `ops`.
///
/// `index` and `parentIndex` are both initialised to 0.
init(inVars: [VarDesc], inOps: [OpDesc]) {
    self.index = 0
    self.parentIndex = 0
    self.vars = inVars
    self.ops = inOps
}
extension BlockDesc: CustomStringConvertible, CustomDebugStringConvertible {
    /// Text made of every operator description followed by every variable
    /// description, concatenated in order with no separator in between.
    var description: String {
        let opText = ops.map { $0.description }.joined()
        let varText = vars.map { $0.description }.joined()
        return opText + varText
    }

    /// Identical to `description`.
    var debugDescription: String {
        return description
    }
}
......@@ -16,7 +16,7 @@ import Foundation
struct OpDesc {
let inputs: [String : [String]]
let paraInputs: [String : [String]]
var paraInputs: [String : [String]]
let outputs: [String : [String]]
let unusedOutputs: [String : [String]]
var attrs: [String : Attr] = [:]
......@@ -56,3 +56,26 @@ struct OpDesc {
extension OpDesc: CustomStringConvertible, CustomDebugStringConvertible {
    /// Multi-line dump of the operator: its type, inputs, parameter inputs,
    /// outputs and attributes, each section introduced by a caption line.
    var description: String {
        let sections: [String] = [
            "op type: \(type): \n",
            " op inputs: \n",
            " \(inputs) \n",
            " op para inputs: \n",
            " \(paraInputs) \n",
            " op para outputs: \n",
            " \(outputs) \n",
            " op attrs: \n",
            " \(attrs) \n"
        ]
        return sections.joined()
    }

    /// Identical to `description`.
    var debugDescription: String {
        return description
    }
}
......@@ -21,4 +21,24 @@ public struct ProgramDesc {
self.blocks.append(BlockDesc.init(block: block))
/// Parameterless initializer.
// NOTE(review): no statements of its body are visible in this view — confirm
// that it really leaves `blocks` at its default value.
init() {
extension ProgramDesc: CustomStringConvertible, CustomDebugStringConvertible {
    /// One section per block: a "block - <index>" caption line followed by
    /// that block's own description, concatenated in block order.
    public var description: String {
        return blocks.enumerated()
            .map { entry -> String in
                let caption = "block - \(entry.offset): \n"
                return caption + entry.element.description
            }
            .joined()
    }

    /// Identical to `description`.
    public var debugDescription: String {
        return description
    }
}
// ProgramOptimize.swift
// paddle-mobile
// Created by liuRuiLong on 2018/7/8.
// Copyright © 2018年 orange. All rights reserved.
import Foundation
/// Precedence group for the node-chaining operator declared below:
/// left-associative and binding tighter than `MultiplicationPrecedence`.
precedencegroup ChainNode {
associativity: left
higherThan: MultiplicationPrecedence
/// Infix operator in the `ChainNode` group; `Node` provides the implementation.
infix operator --> : ChainNode
/// A vertex of the operator graph used for fusion matching.
///
/// A node carries the operator `type`, optionally the `OpDesc` it was created
/// from, and the list of nodes that consume its results (`outputs`).
/// Pattern nodes created with `init(inType:)` have no `opDesc`.
class Node {
    /// Upstream nodes. Nothing in this class reads or writes this list.
    var inputs: [Node] = []
    /// Downstream nodes, in the order they were chained with `-->`.
    var outputs: [Node] = []
    let type: String
    var opDesc: OpDesc?

    /// Creates a graph node for a concrete operator description.
    init(inOpDesc: OpDesc) {
        type = inOpDesc.type
        opDesc = inOpDesc
    }

    /// Creates a description-less node that only carries an operator type.
    init(inType: String) {
        type = inType
    }

    /// Chains `rNode` after `lNode` and returns `rNode`, so that
    /// `a --> b --> c` builds the chain a → b → c (the operator is
    /// left-associative).
    ///
    /// Only `lNode.outputs` is updated. `inputs` is deliberately left alone:
    /// a strong back-reference would create a retain cycle between nodes, and
    /// nothing in this class reads it.
    static func -->(lNode: Node, rNode: Node) -> Node {
        lNode.outputs.append(rNode)
        return rNode
    }

    /// Number of levels on the longest path starting at this node.
    ///
    /// - Parameter begin: level assigned to this node (1 for a root).
    /// - Returns: `begin` for a node without outputs, otherwise the largest
    ///   value reported by any output.
    func depth(begin: UInt = 1) -> UInt {
        return outputs.reduce(begin) { deepest, child in
            max(deepest, child.depth(begin: begin + 1))
        }
    }

    /// Returns a type-only copy of the sub-graph rooted at this node, cut off
    /// after `depth` levels (this node being level 1).
    ///
    /// A `depth` of 0 or 1 yields a copy of this node without outputs.
    func to(depth: UInt) -> Node {
        let beginNode = Node(inType: type)
        // `depth - 1` on a UInt would trap for 0, so only descend when there
        // is at least one more level to copy.
        if depth > 1 {
            to(depth: depth - 1, withNode: beginNode)
        }
        return beginNode
    }

    /// Folds the successors that matched `fusion`'s pattern into this node.
    ///
    /// The attributes and parameter inputs of every folded successor are
    /// merged into this node's `opDesc`; afterwards `outputs` holds the
    /// successors of the folded chain instead of the folded nodes themselves.
    /// Outputs that have no counterpart in the pattern are kept untouched.
    func folderWith(fusion: Fusion.Type) {
        let fusionNode = fusion.fusionNode()
        let change = fusion.change()
        let inOutputs = outputs
        // Folded nodes must not stay reachable from this node.
        outputs.removeAll()
        for (index, output) in inOutputs.enumerated() {
            if index < fusionNode.outputs.count {
                output.folderWith(beginNode: self, matchNode: fusionNode.outputs[index], change: change)
            } else {
                outputs.append(output)
            }
        }
    }

    /// Merges this node into `beginNode` and continues along the pattern.
    ///
    /// - Parameters:
    ///   - beginNode: node that absorbs attributes and parameter inputs.
    ///   - matchNode: pattern node this node was matched against.
    ///   - change: per operator type, the parameter-input keys to rename
    ///     (`from` → `to`) while merging.
    private func folderWith(beginNode: Node, matchNode: Node, change: [String : [(from: String, to: String)]]) {
        guard let inOpdesc = opDesc else {
            // Nothing to merge; keep the node attached so the graph stays connected.
            beginNode.outputs.append(self)
            return
        }

        for attr in inOpdesc.attrs {
            beginNode.opDesc?.attrs[attr.key] = attr.value
        }

        // Each parameter input is stored exactly once: under the renamed key
        // when a rule applies, under its own key otherwise.
        let renames = change[type] ?? []
        for paraInput in inOpdesc.paraInputs {
            let renamed = renames.first(where: { $0.from == paraInput.key })?.to
            beginNode.opDesc?.paraInputs[renamed ?? paraInput.key] = paraInput.value
        }

        if matchNode.outputs.isEmpty {
            // End of the pattern: the fused node inherits this node's successors.
            beginNode.outputs.append(contentsOf: outputs)
            return
        }

        for (index, output) in outputs.enumerated() {
            if index < matchNode.outputs.count {
                output.folderWith(beginNode: beginNode, matchNode: matchNode.outputs[index], change: change)
            } else {
                beginNode.outputs.append(output)
            }
        }
    }

    /// Copies `depth` further levels below this node onto `withNode`.
    private func to(depth: UInt, withNode: Node) {
        guard depth >= 1 else {
            return
        }

        for output in outputs {
            let node = Node(inType: output.type)
            withNode.outputs.append(node)
            output.to(depth: depth - 1, withNode: node)
        }
    }
}
extension Node: Equatable {
    /// Structural equality: two nodes are equal when they have the same
    /// number of outputs, the same `type`, and pairwise equal outputs
    /// (compared recursively, in order).
    static func == (lhs: Node, rhs: Node) -> Bool {
        guard lhs.outputs.count == rhs.outputs.count, lhs.type == rhs.type else {
            return false
        }
        let hasMismatch = zip(lhs.outputs, rhs.outputs).contains(where: { pair in
            pair.0 != pair.1
        })
        return !hasMismatch
    }
}
/// Rewrites a `ProgramDesc` by looking for the operator patterns listed in
/// `fusionOps` and folding each match into a single node.
// NOTE(review): this span reads like a rendered diff with indentation,
// brace-only lines and possibly other short lines missing; the notes below
// describe only what is visible here.
class ProgramOptimize<P: PrecisionType> {
/// Fusion patterns that `optimize` tries, in order.
let fusionOps: [Fusion.Type] = [ConvAddBatchNormReluOp<P>.self]
/// Builds a node per operator of the program's single block, tries every
/// fusion pattern against the nodes of the pattern's root type, and returns a
/// new `ProgramDesc`.
///
/// Calls `fatalError` when the program does not have exactly one block.
func optimize(originProgramDesc: ProgramDesc) -> ProgramDesc {
guard originProgramDesc.blocks.count == 1 else {
fatalError(" not support yet")
// Maps an output variable name to the node that produces it.
var mapForNodeChain: [String : Node] = [:]
// NOTE(review): no statement visible here ever appends to `nodes` — confirm.
var nodes: [Node] = []
// Nodes grouped by operator type.
var typeMapNodes: [String : [Node]] = [:]
let block = originProgramDesc.blocks[0]
for opDesc in block.ops {
// The input/output key lists for this operator type come from `opInfos`.
// NOTE(review): the body of this `else` branch is not visible in this view.
guard let opInputKeys = opInfos[opDesc.type]?.inputs, let outputKeys = opInfos[opDesc.type]?.outputs else {
let node = Node.init(inOpDesc: opDesc)
// For every input variable that has a known producer, apply `-->` to the
// producer and the new node; the operator's result is discarded.
for inputKey in opInputKeys {
if let inputs = opDesc.inputs[inputKey] {
for input in inputs {
if let inputNode = mapForNodeChain[input] {
_ = inputNode --> node
// Record this node as the producer of each of its output variables.
for outputKey in outputKeys {
if let outputs = opDesc.outputs[outputKey] {
for output in outputs {
mapForNodeChain[output] = node
// NOTE(review): as shown, the existing array is written back unchanged and a
// new entry starts out empty, so `node` is never stored in `typeMapNodes`;
// the `if var nodes` binding also shadows the outer `nodes` — confirm.
if var nodes = typeMapNodes[opDesc.type] {
typeMapNodes[opDesc.type] = nodes
} else {
typeMapNodes[opDesc.type] = []
for fusion in fusionOps {
let fusionNode = fusion.fusionNode()
// NOTE(review): `depth` is computed but not used below; the truncation
// depth is the literal 4 — presumably `depth` was intended, confirm.
let depth = fusionNode.depth()
if let nodes = typeMapNodes[fusionNode.type] {
for node in nodes {
// Compare a depth-limited, type-only copy of the graph against the pattern.
let toNode = node.to(depth: 4)
if toNode == fusionNode { // match
node.folderWith(fusion: fusion)
var ops: [OpDesc] = []
// NOTE(review): the body of this loop is not visible in this view.
for node in nodes {
var newProgramDesc = ProgramDesc.init()
// NOTE(review): `newBlock` is not added to `newProgramDesc` in the visible
// lines, so the returned program would have no blocks — confirm.
let newBlock = BlockDesc.init(inVars: block.vars, inOps: ops)
return newProgramDesc
......@@ -76,5 +76,23 @@ struct VarDesc {
tensorDesc = .none
extension VarDesc: CustomStringConvertible, CustomDebugStringConvertible {
    /// The variable's name followed by its dimension count and dimensions, or
    /// by a "no dim info" marker when there is no tensor description.
    var description: String {
        var parts: [String] = ["var name \(name): \n"]
        guard let inTensorDesc = tensorDesc else {
            parts.append(" no dim info")
            return parts.joined()
        }
        parts.append(" dim size: \(inTensorDesc.dims.count) \n")
        parts.append(" dim: \(inTensorDesc.dims) \n")
        return parts.joined()
    }

    /// Identical to `description`.
    var debugDescription: String {
        return description
    }
}
......@@ -29,6 +29,7 @@ extension Tensorial {
class Tensor<P: PrecisionType>: Tensorial {
var data: Data
var dim: Dim
var buffer: MTLBuffer!
private(set) var layout: DataLayout
class Data {
......@@ -37,7 +38,7 @@ class Tensor<P: PrecisionType>: Tensorial {
pointer = inPointer
let size: Int
var pointer: UnsafeMutablePointer<P>
fileprivate var pointer: UnsafeMutablePointer<P>
subscript(index: Int) -> P{
get {
return pointer[index]
......@@ -51,7 +52,7 @@ class Tensor<P: PrecisionType>: Tensorial {
deinit {
// release()
......@@ -87,6 +88,39 @@ class Tensor<P: PrecisionType>: Tensorial {
layout = to
/// Copies the tensor's elements into a newly created Metal buffer and stores
/// it in `buffer`.
///
/// - 4-dimensional tensors in `.NHWC` layout: the last dimension (`C`) is
///   rounded up to a multiple of four. When no rounding is needed the data is
///   copied in one piece; otherwise each group of `C` source elements is
///   written to the start of a `paddedC`-element slot. The extra elements of
///   each slot are not written here and keep whatever `makeBuffer` put there.
/// - 4-dimensional tensors in any other layout: nothing is done.
/// - 1-dimensional tensors: copied verbatim.
/// - Any other rank: `fatalError`.
///
/// If the device cannot create the buffer, `buffer` stays `nil` and nothing
/// is copied.
///
/// - Parameter device: the Metal device that allocates the buffer.
func initBuffer(device: MTLDevice) {
    let elementStride = MemoryLayout<P>.stride
    if dim.cout() == 4 {
        if layout == .NHWC {
            let C = dim[3]
            let cSlices = (C + 3) / 4
            let paddedC = cSlices * 4
            let pixelCount = dim[0] * dim[1] * dim[2]
            let count = paddedC * pixelCount
            buffer = device.makeBuffer(length: count * elementStride)
            guard let contents = buffer?.contents() else {
                return
            }
            if C == paddedC {
                contents.copyMemory(from: data.pointer, byteCount: count * elementStride)
            } else {
                // Source and destination advance in lock-step, one pixel at a
                // time: C elements are read, paddedC elements are skipped.
                var srcPtr = data.pointer
                var dstPtr = contents.bindMemory(to: P.self, capacity: count)
                for _ in 0..<pixelCount {
                    for j in 0..<C {
                        dstPtr[j] = srcPtr[j]
                    }
                    srcPtr += C
                    dstPtr += paddedC
                }
            }
        }
    } else if dim.cout() == 1 {
        let byteCount = numel() * elementStride
        buffer = device.makeBuffer(length: byteCount)
        buffer?.contents().copyMemory(from: data.pointer, byteCount: byteCount)
    } else {
        fatalError(" not support !")
    }
}
func NCHW2NHWC(newPtr: UnsafeMutablePointer<P>) {
let N = dim[0]
let C = dim[1]
......@@ -58,8 +58,13 @@ public class Texture<P: PrecisionType>: Tensorial {
tmpTextureDes.depth = 1
tmpTextureDes.arrayLength = (inDim[3] * inDim[0] + 3)/4
tmpTextureDes.textureType = .type2DArray
} else if inDim.cout() == 2 {
tmpTextureDes.height = inDim[0]
tmpTextureDes.width = inDim[1]
tmpTextureDes.depth = 1
tmpTextureDes.textureType = .type2D
} else {
fatalError(" didn't support yet")
fatalError(" not suuprt ")
if MemoryLayout<P>.size == 1 {
tmpTextureDes.pixelFormat = .rgba8Unorm
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
想要评论请 注册