提交 980b855d 编写于 作者: Y Yanzhan Yang 提交者: GitHub

Merge pull request #1582 from yangyanzhan/develop

fix MPSCNN ConvAdd implementation.
......@@ -17,8 +17,6 @@ import MetalPerformanceShaders
/// Global dictionary of `MPSCNNConvolution` objects keyed by string.
/// In this file, `ConvAddKernel` stores an entry during initialization,
/// looks it up by `identifyingKey` in `compute`, and removes it in `deinit`.
@available(iOS 10.0, *)
var convDic: [String : MPSCNNConvolution] = [:]
/// Global dictionary of `MPSImage` objects keyed by string.
/// In this file, entries are stored under `identifyingKey + "_input"` and
/// `identifyingKey + "_output"`, wrapping the input and output Metal textures.
@available(iOS 10.0, *)
var imageDic: [String : MPSImage] = [:]
/// 获取唯一字符串
///
......@@ -117,23 +115,22 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
if #available(iOS 11.0, *) {
var desc: MPSCNNConvolutionDescriptor?
// 如果不是 depth wise, 并且输入输出 tensor channel 都大于 4
if !(param.filter.tensorDim[1] == 1 && param.filter.tensorDim[0] == param.input.tensorDim[1]) && param.input.tensorDim[1] > 4 && param.output.tensorDim[1] > 4 {
desc = MPSCNNConvolutionDescriptor(kernelWidth: param.filter.tensorDim[3],
let isDepthWise = param.filter.tensorDim[1] == 1 && param.filter.tensorDim[0] == param.input.tensorDim[1]
if param.input.tensorDim[1] > 4 && param.output.tensorDim[1] > 4 {
if isDepthWise {
desc = MPSCNNDepthWiseConvolutionDescriptor(kernelWidth: param.filter.tensorDim[3],
kernelHeight: param.filter.tensorDim[2],
inputFeatureChannels: param.input.tensorDim[1],
outputFeatureChannels: param.output.tensorDim[1],
neuronFilter: nil)
desc?.strideInPixelsX = Int(param.stride[0])
desc?.strideInPixelsY = Int(param.stride[1])
} else if param.input.tensorDim[1] > 4 && param.output.tensorDim[1] > 4 {
desc = MPSCNNDepthWiseConvolutionDescriptor(kernelWidth: param.filter.tensorDim[3],
} else {
desc = MPSCNNConvolutionDescriptor(kernelWidth: param.filter.tensorDim[3],
kernelHeight: param.filter.tensorDim[2],
inputFeatureChannels: param.input.tensorDim[1],
outputFeatureChannels: param.output.tensorDim[1],
neuronFilter: nil)
}
}
desc?.strideInPixelsX = Int(param.stride[0])
desc?.strideInPixelsY = Int(param.stride[1])
if let inDesc = desc {
......@@ -143,8 +140,6 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
conv.offset = MPSOffset.init(x: offsetX, y: offsetY, z: 0)
conv.edgeMode = .zero
convDic[key] = conv
imageDic[identifyingKey + "_input"] = MPSImage.init(texture: param.input.metalTexture, featureChannels: param.input.tensorDim[1])
imageDic[identifyingKey + "_output"] = MPSImage.init(texture: param.output.metalTexture, featureChannels: param.output.tensorDim[1])
super.init(device: device, inFunctionName: nil, initContext: initContext)
return
}
......@@ -201,7 +196,9 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
func compute(commandBuffer: MTLCommandBuffer, param: ConvAddParam<P>) throws {
if #available(iOS 10.0, *) {
if let conv = convDic[identifyingKey], let inputImage = imageDic[identifyingKey + "_input"], let outputImage = imageDic[identifyingKey + "_output"] {
if let conv = convDic[identifyingKey] {
let inputImage = MPSImage.init(texture: param.input.metalTexture, featureChannels: param.input.tensorDim[1])
let outputImage = MPSImage.init(texture: param.output.metalTexture, featureChannels: param.output.tensorDim[1])
conv.encode(commandBuffer: commandBuffer, sourceImage: inputImage, destinationImage: outputImage)
return;
}
......@@ -222,8 +219,6 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
/// Removes the entries registered under this kernel's `identifyingKey`
/// from the global `convDic` and `imageDic` dictionaries.
deinit {
// The global dictionaries are declared with @available(iOS 10.0, *),
// so they are only touched behind the matching availability check.
if #available(iOS 10.0, *) {
// Drop the convolution stored for this kernel.
convDic.removeValue(forKey: identifyingKey)
// Drop the images stored under the "_input" / "_output" suffixed keys.
imageDic.removeValue(forKey: identifyingKey + "_input")
imageDic.removeValue(forKey: identifyingKey + "_output")
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册