提交 980b855d 编写于 作者: Y Yanzhan Yang 提交者: GitHub

Merge pull request #1582 from yangyanzhan/develop

fix MPSCNN ConvAdd implementation.
...@@ -17,8 +17,6 @@ import MetalPerformanceShaders ...@@ -17,8 +17,6 @@ import MetalPerformanceShaders
@available(iOS 10.0, *) @available(iOS 10.0, *)
var convDic: [String : MPSCNNConvolution] = [:] var convDic: [String : MPSCNNConvolution] = [:]
@available(iOS 10.0, *)
var imageDic: [String : MPSImage] = [:]
/// 获取唯一字符串 /// 获取唯一字符串
/// ///
...@@ -117,23 +115,22 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable { ...@@ -117,23 +115,22 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
if #available(iOS 11.0, *) { if #available(iOS 11.0, *) {
var desc: MPSCNNConvolutionDescriptor? var desc: MPSCNNConvolutionDescriptor?
// 如果不是 depth wise, 并且输入输出 tensor channel 都大于 4 // 如果不是 depth wise, 并且输入输出 tensor channel 都大于 4
if !(param.filter.tensorDim[1] == 1 && param.filter.tensorDim[0] == param.input.tensorDim[1]) && param.input.tensorDim[1] > 4 && param.output.tensorDim[1] > 4 { let isDepthWise = param.filter.tensorDim[1] == 1 && param.filter.tensorDim[0] == param.input.tensorDim[1]
desc = MPSCNNConvolutionDescriptor(kernelWidth: param.filter.tensorDim[3], if param.input.tensorDim[1] > 4 && param.output.tensorDim[1] > 4 {
kernelHeight: param.filter.tensorDim[2], if isDepthWise {
inputFeatureChannels: param.input.tensorDim[1], desc = MPSCNNDepthWiseConvolutionDescriptor(kernelWidth: param.filter.tensorDim[3],
outputFeatureChannels: param.output.tensorDim[1],
neuronFilter: nil)
desc?.strideInPixelsX = Int(param.stride[0])
desc?.strideInPixelsY = Int(param.stride[1])
} else if param.input.tensorDim[1] > 4 && param.output.tensorDim[1] > 4 {
desc = MPSCNNDepthWiseConvolutionDescriptor(kernelWidth: param.filter.tensorDim[3],
kernelHeight: param.filter.tensorDim[2], kernelHeight: param.filter.tensorDim[2],
inputFeatureChannels: param.input.tensorDim[1], inputFeatureChannels: param.input.tensorDim[1],
outputFeatureChannels: param.output.tensorDim[1], outputFeatureChannels: param.output.tensorDim[1],
neuronFilter: nil) neuronFilter: nil)
} else {
desc = MPSCNNConvolutionDescriptor(kernelWidth: param.filter.tensorDim[3],
kernelHeight: param.filter.tensorDim[2],
inputFeatureChannels: param.input.tensorDim[1],
outputFeatureChannels: param.output.tensorDim[1],
neuronFilter: nil)
}
} }
desc?.strideInPixelsX = Int(param.stride[0]) desc?.strideInPixelsX = Int(param.stride[0])
desc?.strideInPixelsY = Int(param.stride[1]) desc?.strideInPixelsY = Int(param.stride[1])
if let inDesc = desc { if let inDesc = desc {
...@@ -143,8 +140,6 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable { ...@@ -143,8 +140,6 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
conv.offset = MPSOffset.init(x: offsetX, y: offsetY, z: 0) conv.offset = MPSOffset.init(x: offsetX, y: offsetY, z: 0)
conv.edgeMode = .zero conv.edgeMode = .zero
convDic[key] = conv convDic[key] = conv
imageDic[identifyingKey + "_input"] = MPSImage.init(texture: param.input.metalTexture, featureChannels: param.input.tensorDim[1])
imageDic[identifyingKey + "_output"] = MPSImage.init(texture: param.output.metalTexture, featureChannels: param.output.tensorDim[1])
super.init(device: device, inFunctionName: nil, initContext: initContext) super.init(device: device, inFunctionName: nil, initContext: initContext)
return return
} }
...@@ -201,7 +196,9 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable { ...@@ -201,7 +196,9 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
func compute(commandBuffer: MTLCommandBuffer, param: ConvAddParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: ConvAddParam<P>) throws {
if #available(iOS 10.0, *) { if #available(iOS 10.0, *) {
if let conv = convDic[identifyingKey], let inputImage = imageDic[identifyingKey + "_input"], let outputImage = imageDic[identifyingKey + "_output"] { if let conv = convDic[identifyingKey] {
let inputImage = MPSImage.init(texture: param.input.metalTexture, featureChannels: param.input.tensorDim[1])
let outputImage = MPSImage.init(texture: param.output.metalTexture, featureChannels: param.output.tensorDim[1])
conv.encode(commandBuffer: commandBuffer, sourceImage: inputImage, destinationImage: outputImage) conv.encode(commandBuffer: commandBuffer, sourceImage: inputImage, destinationImage: outputImage)
return; return;
} }
...@@ -222,8 +219,6 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable { ...@@ -222,8 +219,6 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
deinit { deinit {
if #available(iOS 10.0, *) { if #available(iOS 10.0, *) {
convDic.removeValue(forKey: identifyingKey) convDic.removeValue(forKey: identifyingKey)
imageDic.removeValue(forKey: identifyingKey + "_input")
imageDic.removeValue(forKey: identifyingKey + "_output")
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册