提交 200be579 编写于 作者: Y yangyanzhan

fuse Conv-Add-Relu into one op.

上级 99cccf3b
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
objects = { objects = {
/* Begin PBXBuildFile section */ /* Begin PBXBuildFile section */
165F38D72276F4C00088E29F /* ConvAddReluMetal.metal in Sources */ = {isa = PBXBuildFile; fileRef = 165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */; };
5CCC0CF6759710BAFE999DB7 /* Pods_paddle_mobile_metallib.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5D9D330A035906298947080B /* Pods_paddle_mobile_metallib.framework */; }; 5CCC0CF6759710BAFE999DB7 /* Pods_paddle_mobile_metallib.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5D9D330A035906298947080B /* Pods_paddle_mobile_metallib.framework */; };
FCC15DE5221E69E100DC3CB2 /* ReluKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCC15DBC221E69DD00DC3CB2 /* ReluKernel.metal */; }; FCC15DE5221E69E100DC3CB2 /* ReluKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCC15DBC221E69DD00DC3CB2 /* ReluKernel.metal */; };
FCC15DE6221E69E100DC3CB2 /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCC15DBD221E69DD00DC3CB2 /* BoxCoder.metal */; }; FCC15DE6221E69E100DC3CB2 /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCC15DBD221E69DD00DC3CB2 /* BoxCoder.metal */; };
...@@ -52,6 +53,7 @@ ...@@ -52,6 +53,7 @@
/* End PBXBuildFile section */ /* End PBXBuildFile section */
/* Begin PBXFileReference section */ /* Begin PBXFileReference section */
165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvAddReluMetal.metal; sourceTree = "<group>"; };
33511F4FF7FE78679BE12DC0 /* Pods-paddle-mobile-metallib.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-metallib.release.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-metallib/Pods-paddle-mobile-metallib.release.xcconfig"; sourceTree = "<group>"; }; 33511F4FF7FE78679BE12DC0 /* Pods-paddle-mobile-metallib.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-metallib.release.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-metallib/Pods-paddle-mobile-metallib.release.xcconfig"; sourceTree = "<group>"; };
5D9D330A035906298947080B /* Pods_paddle_mobile_metallib.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile_metallib.framework; sourceTree = BUILT_PRODUCTS_DIR; }; 5D9D330A035906298947080B /* Pods_paddle_mobile_metallib.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile_metallib.framework; sourceTree = BUILT_PRODUCTS_DIR; };
C6D31B9F9533810DBCA6B28D /* Pods-paddle-mobile-metallib.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-metallib.debug.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-metallib/Pods-paddle-mobile-metallib.debug.xcconfig"; sourceTree = "<group>"; }; C6D31B9F9533810DBCA6B28D /* Pods-paddle-mobile-metallib.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-metallib.debug.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-metallib/Pods-paddle-mobile-metallib.debug.xcconfig"; sourceTree = "<group>"; };
...@@ -190,6 +192,7 @@ ...@@ -190,6 +192,7 @@
FCC15DBF221E69DD00DC3CB2 /* Split.metal */, FCC15DBF221E69DD00DC3CB2 /* Split.metal */,
FCC15DC9221E69DE00DC3CB2 /* TransposeKernel.inc.metal */, FCC15DC9221E69DE00DC3CB2 /* TransposeKernel.inc.metal */,
FCC15DDA221E69E000DC3CB2 /* TransposeKernel.metal */, FCC15DDA221E69E000DC3CB2 /* TransposeKernel.metal */,
165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */,
); );
path = "paddle-mobile-metallib"; path = "paddle-mobile-metallib";
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -310,6 +313,7 @@ ...@@ -310,6 +313,7 @@
FCC15E08221E69E100DC3CB2 /* Split.inc.metal in Sources */, FCC15E08221E69E100DC3CB2 /* Split.inc.metal in Sources */,
FCC15DF4221E69E100DC3CB2 /* ResizeBilinear.metal in Sources */, FCC15DF4221E69E100DC3CB2 /* ResizeBilinear.metal in Sources */,
FCC15E05221E69E100DC3CB2 /* BatchNormKernel.metal in Sources */, FCC15E05221E69E100DC3CB2 /* BatchNormKernel.metal in Sources */,
165F38D72276F4C00088E29F /* ConvAddReluMetal.metal in Sources */,
FCC15DE6221E69E100DC3CB2 /* BoxCoder.metal in Sources */, FCC15DE6221E69E100DC3CB2 /* BoxCoder.metal in Sources */,
FCC15DF6221E69E100DC3CB2 /* PoolKernel.metal in Sources */, FCC15DF6221E69E100DC3CB2 /* PoolKernel.metal in Sources */,
FCC15E09221E69E100DC3CB2 /* ConcatKernel.inc.metal in Sources */, FCC15E09221E69E100DC3CB2 /* ConcatKernel.inc.metal in Sources */,
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
objects = { objects = {
/* Begin PBXBuildFile section */ /* Begin PBXBuildFile section */
165F38D32276CDEA0088E29F /* ConvAddReluOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 165F38D22276CDEA0088E29F /* ConvAddReluOp.swift */; };
165F38D52276CE7D0088E29F /* ConvAddReluKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 165F38D42276CE7D0088E29F /* ConvAddReluKernel.swift */; };
456BB7B421F5B356001474E2 /* Framework.pbobjc.m in Sources */ = {isa = PBXBuildFile; fileRef = 456BB7B221F5B356001474E2 /* Framework.pbobjc.m */; settings = {COMPILER_FLAGS = "-fno-objc-arc"; }; }; 456BB7B421F5B356001474E2 /* Framework.pbobjc.m in Sources */ = {isa = PBXBuildFile; fileRef = 456BB7B221F5B356001474E2 /* Framework.pbobjc.m */; settings = {COMPILER_FLAGS = "-fno-objc-arc"; }; };
456BB7B521F5B356001474E2 /* Framework.pbobjc.h in Headers */ = {isa = PBXBuildFile; fileRef = 456BB7B321F5B356001474E2 /* Framework.pbobjc.h */; settings = {ATTRIBUTES = (Public, ); }; }; 456BB7B521F5B356001474E2 /* Framework.pbobjc.h in Headers */ = {isa = PBXBuildFile; fileRef = 456BB7B321F5B356001474E2 /* Framework.pbobjc.h */; settings = {ATTRIBUTES = (Public, ); }; };
4AA1EA862146625E00D0F791 /* BilinearInterpOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA852146625E00D0F791 /* BilinearInterpOp.swift */; }; 4AA1EA862146625E00D0F791 /* BilinearInterpOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA852146625E00D0F791 /* BilinearInterpOp.swift */; };
...@@ -101,6 +103,8 @@ ...@@ -101,6 +103,8 @@
/* End PBXBuildFile section */ /* End PBXBuildFile section */
/* Begin PBXFileReference section */ /* Begin PBXFileReference section */
165F38D22276CDEA0088E29F /* ConvAddReluOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddReluOp.swift; sourceTree = "<group>"; };
165F38D42276CE7D0088E29F /* ConvAddReluKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConvAddReluKernel.swift; sourceTree = "<group>"; };
456BB7B221F5B356001474E2 /* Framework.pbobjc.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = Framework.pbobjc.m; sourceTree = "<group>"; }; 456BB7B221F5B356001474E2 /* Framework.pbobjc.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = Framework.pbobjc.m; sourceTree = "<group>"; };
456BB7B321F5B356001474E2 /* Framework.pbobjc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = Framework.pbobjc.h; sourceTree = "<group>"; }; 456BB7B321F5B356001474E2 /* Framework.pbobjc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = Framework.pbobjc.h; sourceTree = "<group>"; };
4AA1EA852146625E00D0F791 /* BilinearInterpOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BilinearInterpOp.swift; sourceTree = "<group>"; }; 4AA1EA852146625E00D0F791 /* BilinearInterpOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BilinearInterpOp.swift; sourceTree = "<group>"; };
...@@ -323,6 +327,7 @@ ...@@ -323,6 +327,7 @@
FC803BBE214CB65A0094B8E5 /* ConvAddPreluOp.swift */, FC803BBE214CB65A0094B8E5 /* ConvAddPreluOp.swift */,
FCE3A1A82153DE5100C37CDE /* ConvAddAddPreluOp.swift */, FCE3A1A82153DE5100C37CDE /* ConvAddAddPreluOp.swift */,
FCE3A1AC2153E8BA00C37CDE /* ElementwiseAddPreluOp.swift */, FCE3A1AC2153E8BA00C37CDE /* ElementwiseAddPreluOp.swift */,
165F38D22276CDEA0088E29F /* ConvAddReluOp.swift */,
); );
path = Operators; path = Operators;
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -377,6 +382,7 @@ ...@@ -377,6 +382,7 @@
FCE3A1AE2153E8EE00C37CDE /* ElementwiseAddPreluKernel.swift */, FCE3A1AE2153E8EE00C37CDE /* ElementwiseAddPreluKernel.swift */,
FC2BFD4521DF685F00C262B2 /* Scale.swift */, FC2BFD4521DF685F00C262B2 /* Scale.swift */,
FCB40E5821E0DCAB0075EC91 /* FetchKernel.swift */, FCB40E5821E0DCAB0075EC91 /* FetchKernel.swift */,
165F38D42276CE7D0088E29F /* ConvAddReluKernel.swift */,
); );
path = Kernels; path = Kernels;
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -541,6 +547,7 @@ ...@@ -541,6 +547,7 @@
FCBCCC5B2122F66F00D94F7E /* ConvBNReluKernel.swift in Sources */, FCBCCC5B2122F66F00D94F7E /* ConvBNReluKernel.swift in Sources */,
4AA1EA8C2146640900D0F791 /* SplitOp.swift in Sources */, 4AA1EA8C2146640900D0F791 /* SplitOp.swift in Sources */,
FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */, FC0E2DC020EE461F009C1FAC /* ElementwiseAddKernel.swift in Sources */,
165F38D52276CE7D0088E29F /* ConvAddReluKernel.swift in Sources */,
FC803BBF214CB65A0094B8E5 /* ConvAddPreluOp.swift in Sources */, FC803BBF214CB65A0094B8E5 /* ConvAddPreluOp.swift in Sources */,
FCEB684C212F093800D2448E /* PreluOp.swift in Sources */, FCEB684C212F093800D2448E /* PreluOp.swift in Sources */,
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */, FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */,
...@@ -592,6 +599,7 @@ ...@@ -592,6 +599,7 @@
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */, FC9D038220E2312E000F735A /* FetchOp.swift in Sources */,
FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */, FC039BBD20E11CC20081E9F8 /* Program.swift in Sources */,
FC039BA220E11CB70081E9F8 /* Loader.swift in Sources */, FC039BA220E11CB70081E9F8 /* Loader.swift in Sources */,
165F38D32276CDEA0088E29F /* ConvAddReluOp.swift in Sources */,
FCBCCC67212306B000D94F7E /* ConcatOp.swift in Sources */, FCBCCC67212306B000D94F7E /* ConcatOp.swift in Sources */,
FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */, FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */,
FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */, FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */,
......
...@@ -68,7 +68,8 @@ class OpCreator<P: PrecisionProtocol> { ...@@ -68,7 +68,8 @@ class OpCreator<P: PrecisionProtocol> {
gConvAddPreluType : ConvAddPreluOp<P>.creat, gConvAddPreluType : ConvAddPreluOp<P>.creat,
gConvAddAddPreluType : ConvAddAddPreluOp<P>.creat, gConvAddAddPreluType : ConvAddAddPreluOp<P>.creat,
gElementwiseAddPreluType : ElementwiseAddPreluOp<P>.creat, gElementwiseAddPreluType : ElementwiseAddPreluOp<P>.creat,
gFusionConvAddType : ConvAddOp<P>.creat] gFusionConvAddType : ConvAddOp<P>.creat,
gConvAddReluType : ConvAddReluOp<P>.creat]
private init(){} private init(){}
} }
...@@ -173,6 +173,7 @@ let gBilinearInterpType = "bilinear_interp" ...@@ -173,6 +173,7 @@ let gBilinearInterpType = "bilinear_interp"
let gSplit = "split" let gSplit = "split"
let gShape = "shape" let gShape = "shape"
let gFlatten = "flatten" let gFlatten = "flatten"
let gConvAddReluType = "conv_add_relu"
let gConvAddPreluType = "conv_add_prelu" let gConvAddPreluType = "conv_add_prelu"
let gConvAddAddPreluType = "conv_add_add_prelu" let gConvAddAddPreluType = "conv_add_add_prelu"
let gElementwiseAddPreluType = "elementwise_add_prelu" let gElementwiseAddPreluType = "elementwise_add_prelu"
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class ConvAddReluOp<P: PrecisionProtocol>: Operator<ConvAddReluKernel<P>, ConvAddParam<P>>, Runable, Creator, InferShaperable, Fusion {
typealias OpType = ConvAddReluOp<P>
static func fusionNode() -> Node {
let beginNode = Node.init(inType: gConvType)
_ = beginNode
--> Node.init(inType: gElementwiseAddType)
--> Node.init(inType: gReluType)
return beginNode
}
static func change() -> [String : [(from: String, to: String)]] {
return [:]
}
static func fusionType() -> String {
return gConvAddReluType
}
func inferShape() {
let inDims = para.input.dim
let filterDim = para.filter.dim
let strides = para.stride
let paddings = para.paddings
let dilations = para.dilations
var outDim = [inDims[0]]
for i in 0..<strides.count {
let dilation: Int = Int(dilations[i])
let filterSize: Int = filterDim[i + 1]
let inputSize: Int = inDims[i + 1]
let padding: Int = Int(paddings[i])
let stride: Int = Int(strides[i])
let dKernel = dilation * (filterSize - 1) + 1
let outputSize = (inputSize + 2 * padding - dKernel) / stride + 1
outDim.append(outputSize)
}
outDim.append(filterDim[0])
para.output.dim = Dim.init(inDim: outDim)
}
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
}
func delogOutput() {
print(" \(type) output: ")
print(para.output.metalTexture)
print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray())
}
}
...@@ -103,123 +103,135 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable { ...@@ -103,123 +103,135 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
let identifyingKey: String = getUniqueKey() let identifyingKey: String = getUniqueKey()
required init(device: MTLDevice, param: ConvAddParam<P>, initContext: InitContext) { required init(device: MTLDevice, param: ConvAddParam<P>, initContext: InitContext) {
param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision) param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
let offsetY = (Int(param.dilations[1]) * (param.filter.tensorDim[2] - 1) + 1)/2 - Int(param.paddings[1]) var shouldUseMPS = false
let offsetX = (Int(param.dilations[0]) * (param.filter.tensorDim[3] - 1) + 1)/2 - Int(param.paddings[0]) if #available(iOS 11.0, *), initContext.useMPS {
// 输入输出 tensor channel 必须都大于 4
if param.input.tensorDim[1] > 4 && param.output.tensorDim[1] > 4 {
shouldUseMPS = true
}
}
if shouldUseMPS {
super.init(device: device, inFunctionName: nil, initContext: initContext)
setupWithMPS(device: device, param: param)
} else {
let functionName = type(of: self).kernelFunctionName(param: param)
if functionName == nil {
fatalError(" unsupport yet ")
}
super.init(device: device, inFunctionName: functionName, initContext: initContext)
setupWithoutMPS(device: device, param: param)
}
}
func compute(commandBuffer: MTLCommandBuffer, param: ConvAddParam<P>) throws {
if #available(iOS 10.0, *) {
if let conv = convDic[identifyingKey] {
let inputImage = MPSImage.init(texture: param.input.metalTexture, featureChannels: param.input.tensorDim[1])
let outputImage = MPSImage.init(texture: param.output.metalTexture, featureChannels: param.output.tensorDim[1])
conv.encode(commandBuffer: commandBuffer, sourceImage: inputImage, destinationImage: outputImage)
return
}
}
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setBytes(&metalParam, length: MemoryLayout<MetalConvParam>.size, index: 0)
encoder.setBuffer(param.filter.buffer, offset: 0, index: 1)
encoder.setBuffer(param.y.buffer, offset: 0, index: 2)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
deinit {
if #available(iOS 10.0, *) {
convDic.removeValue(forKey: identifyingKey)
}
}
func setupWithMPS(device: MTLDevice, param: ConvAddParam<P>) {
let offsetX = (Int(param.dilations[0]) * (param.filter.tensorDim[3] - 1) + 1) / 2 - Int(param.paddings[0])
let offsetY = (Int(param.dilations[1]) * (param.filter.tensorDim[2] - 1) + 1) / 2 - Int(param.paddings[1])
let key = identifyingKey let key = identifyingKey
if initContext.useMPS { // 使用 apple 的 MetalPerformanceShaders
if #available(iOS 11.0, *) {
var desc: MPSCNNConvolutionDescriptor?
// 如果不是 depth wise, 并且输入输出 tensor channel 都大于 4
let isDepthWise = param.filter.tensorDim[1] == 1 && param.filter.tensorDim[0] == param.input.tensorDim[1] let isDepthWise = param.filter.tensorDim[1] == 1 && param.filter.tensorDim[0] == param.input.tensorDim[1]
if param.input.tensorDim[1] > 4 && param.output.tensorDim[1] > 4 { if #available(iOS 11.0, *) {
if isDepthWise { let desc: MPSCNNConvolutionDescriptor = isDepthWise ?
desc = MPSCNNDepthWiseConvolutionDescriptor(kernelWidth: param.filter.tensorDim[3], MPSCNNDepthWiseConvolutionDescriptor(kernelWidth: param.filter.tensorDim[3],
kernelHeight: param.filter.tensorDim[2], kernelHeight: param.filter.tensorDim[2],
inputFeatureChannels: param.input.tensorDim[1], inputFeatureChannels: param.input.tensorDim[1],
outputFeatureChannels: param.output.tensorDim[1], outputFeatureChannels: param.output.tensorDim[1],
neuronFilter: nil) neuronFilter: neuronFilterForMPSLayer(device: device) as? MPSCNNNeuron) :
} else { MPSCNNConvolutionDescriptor(kernelWidth: param.filter.tensorDim[3],
desc = MPSCNNConvolutionDescriptor(kernelWidth: param.filter.tensorDim[3],
kernelHeight: param.filter.tensorDim[2], kernelHeight: param.filter.tensorDim[2],
inputFeatureChannels: param.input.tensorDim[1], inputFeatureChannels: param.input.tensorDim[1],
outputFeatureChannels: param.output.tensorDim[1], outputFeatureChannels: param.output.tensorDim[1],
neuronFilter: nil) neuronFilter: neuronFilterForMPSLayer(device: device) as? MPSCNNNeuron)
} desc.strideInPixelsX = Int(param.stride[0])
} desc.strideInPixelsY = Int(param.stride[1])
desc?.strideInPixelsX = Int(param.stride[0])
desc?.strideInPixelsY = Int(param.stride[1])
if let inDesc = desc {
let _ = param.filter.convert(converter: MPSPointerConverter<P>.init()) let _ = param.filter.convert(converter: MPSPointerConverter<P>.init())
let dataSource = ConvDataSource.init(inDesc: inDesc, inWeights: param.filter, inBiasTerms: param.y) let dataSource = ConvDataSource.init(inDesc: desc, inWeights: param.filter, inBiasTerms: param.y)
let conv = MPSCNNConvolution.init(device: device, weights: dataSource) let conv = MPSCNNConvolution.init(device: device, weights: dataSource)
conv.offset = MPSOffset.init(x: offsetX, y: offsetY, z: 0) conv.offset = MPSOffset.init(x: offsetX, y: offsetY, z: 0)
conv.edgeMode = .zero conv.edgeMode = .zero
convDic[key] = conv convDic[key] = conv
super.init(device: device, inFunctionName: nil, initContext: initContext)
return
}
} }
} }
func setupWithoutMPS(device: MTLDevice, param: ConvAddParam<P>) {
let offsetX = (Int(param.dilations[0]) * (param.filter.tensorDim[3] - 1) + 1) / 2 - Int(param.paddings[0])
let offsetY = (Int(param.dilations[1]) * (param.filter.tensorDim[2] - 1) + 1) / 2 - Int(param.paddings[1])
let offsetZ = 0.0
let inMetalParam = MetalConvParam.init(offsetX: Int16(offsetX), offsetY: Int16(offsetY), offsetZ: Int16(offsetZ), strideX: UInt16(param.stride[0]), strideY: UInt16(param.stride[1]), dilationX: UInt16(param.dilations[0]), dilationY: UInt16(param.dilations[1]))
metalParam = inMetalParam
let padWhenOneC = !(param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1]) let padWhenOneC = !(param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1])
param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision, padWhenOneC: padWhenOneC) param.filter.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision, padWhenOneC: padWhenOneC)
param.y.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision) param.y.initBuffer(device: device, precision: GlobalConfig.shared.computePrecision)
}
open class func kernelFunctionName(param: ConvAddParam<P>) -> String? {
if GlobalConfig.shared.computePrecision == .Float16 { if GlobalConfig.shared.computePrecision == .Float16 {
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_add_1x1_half", initContext: initContext) return "conv_add_1x1_half"
} else if param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1] { } else if param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1] {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3_half", initContext: initContext) return "depthwise_conv_add_3x3_half"
} else if param.filter.width == 3 && param.filter.height == 3 { } else if param.filter.width == 3 && param.filter.height == 3 {
super.init(device: device, inFunctionName: "conv_add_3x3_half", initContext: initContext) return "conv_add_3x3_half"
} else if param.filter.width == 1 && param.filter.height == 5 { } else if param.filter.width == 1 && param.filter.height == 5 {
super.init(device: device, inFunctionName: "conv_add_5x1_half", initContext: initContext) return "conv_add_5x1_half"
} else if param.filter.width == 5 && param.filter.height == 1 { } else if param.filter.width == 5 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_add_1x5_half", initContext: initContext) return "conv_add_1x5_half"
} else { } else {
fatalError(" unsupport yet ") return nil
} }
} else if GlobalConfig.shared.computePrecision == .Float32 { } else if GlobalConfig.shared.computePrecision == .Float32 {
if param.filter.width == 1 && param.filter.height == 1 { if param.filter.width == 1 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_add_1x1", initContext: initContext) return "conv_add_1x1"
} else if param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1] { } else if param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1] {
super.init(device: device, inFunctionName: "depthwise_conv_add_3x3", initContext: initContext) return "depthwise_conv_add_3x3"
} else if param.filter.width == 1 && param.filter.height == 5 { } else if param.filter.width == 1 && param.filter.height == 5 {
super.init(device: device, inFunctionName: "conv_add_5x1", initContext: initContext) return "conv_add_5x1"
} else if param.filter.width == 5 && param.filter.height == 1 { } else if param.filter.width == 5 && param.filter.height == 1 {
super.init(device: device, inFunctionName: "conv_add_1x5", initContext: initContext) return "conv_add_1x5"
} else if param.filter.width == 3 && param.filter.height == 3 { } else if param.filter.width == 3 && param.filter.height == 3 {
super.init(device: device, inFunctionName: "conv_add_3x3", initContext: initContext) return "conv_add_3x3"
} else { } else {
fatalError(" unsupport yet ") return nil
} }
} else { } else {
fatalError() return nil
} }
// print(" function: \(functionName)")
// print("offset x: \(offsetX)")
// print("offset y: \(offsetY)")
let offsetZ = 0.0
let inMetalParam = MetalConvParam.init(offsetX: Int16(offsetX), offsetY: Int16(offsetY), offsetZ: Int16(offsetZ), strideX: UInt16(param.stride[0]), strideY: UInt16(param.stride[1]), dilationX: UInt16(param.dilations[0]), dilationY: UInt16(param.dilations[1]))
// print("metal param: ")
// print(inMetalParam)
metalParam = inMetalParam
} }
func compute(commandBuffer: MTLCommandBuffer, param: ConvAddParam<P>) throws { func neuronFilterForMPSLayer(device: MTLDevice) -> AnyObject? {
if #available(iOS 10.0, *) { return nil
if let conv = convDic[identifyingKey] {
let inputImage = MPSImage.init(texture: param.input.metalTexture, featureChannels: param.input.tensorDim[1])
let outputImage = MPSImage.init(texture: param.output.metalTexture, featureChannels: param.output.tensorDim[1])
conv.encode(commandBuffer: commandBuffer, sourceImage: inputImage, destinationImage: outputImage)
return;
}
}
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setBytes(&metalParam, length: MemoryLayout<MetalConvParam>.size, index: 0)
encoder.setBuffer(param.filter.buffer, offset: 0, index: 1)
encoder.setBuffer(param.y.buffer, offset: 0, index: 2)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
deinit {
if #available(iOS 10.0, *) {
convDic.removeValue(forKey: identifyingKey)
}
} }
} }
//
// ConvAddReluKernel.swift
// paddle-mobile
//
// Created by Yang,Yanzhan on 2019/4/29.
// Copyright © 2019 orange. All rights reserved.
//
import Foundation
import MetalPerformanceShaders
class ConvAddReluKernel<P: PrecisionProtocol>: ConvAddKernel<P> {
override class func kernelFunctionName(param: ConvAddParam<P>) -> String? {
if GlobalConfig.shared.computePrecision == .Float16 {
if param.filter.width == 1 && param.filter.height == 1 {
return "conv_add_relu_1x1_half"
} else if param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1] {
return "depthwise_conv_add_relu_3x3_half"
} else if param.filter.width == 3 && param.filter.height == 3 {
return "conv_add_relu_3x3_half"
} else if param.filter.width == 1 && param.filter.height == 5 {
return "conv_add_relu_5x1_half"
} else if param.filter.width == 5 && param.filter.height == 1 {
return "conv_add_relu_1x5_half"
} else {
return nil
}
} else if GlobalConfig.shared.computePrecision == .Float32 {
if param.filter.width == 1 && param.filter.height == 1 {
return "conv_add_relu_1x1"
} else if param.filter.channel == 1 && param.filter.n == param.input.tensorDim[1] {
return "depthwise_conv_add_relu_3x3"
} else if param.filter.width == 1 && param.filter.height == 5 {
return "conv_add_relu_5x1"
} else if param.filter.width == 5 && param.filter.height == 1 {
return "conv_add_relu_1x5"
} else if param.filter.width == 3 && param.filter.height == 3 {
return "conv_add_relu_3x3"
} else {
return nil
}
} else {
return nil
}
}
override func neuronFilterForMPSLayer(device: MTLDevice) -> AnyObject? {
if #available(iOS 10.0, *) {
return MPSCNNNeuronReLU(device: device, a: 0)
}
return nil
}
}
...@@ -184,6 +184,7 @@ extension Node: Equatable { ...@@ -184,6 +184,7 @@ extension Node: Equatable {
class ProgramOptimize<P: PrecisionProtocol> { class ProgramOptimize<P: PrecisionProtocol> {
// register fusion // register fusion
let fusionOps: [Fusion.Type] = [ConvAddBatchNormReluOp<P>.self, let fusionOps: [Fusion.Type] = [ConvAddBatchNormReluOp<P>.self,
ConvAddReluOp<P>.self,
// ConvAddAddPreluOp<P>.self, // ConvAddAddPreluOp<P>.self,
ConvAddPreluOp<P>.self, ConvAddPreluOp<P>.self,
ConvAddOp<P>.self, ConvAddOp<P>.self,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册