提交 bdf037cf 编写于 作者: Y Yanzhan Yang 提交者: GitHub

remove thread unsafe dictionary (#1617)

上级 c3c6cbd3
......@@ -15,16 +15,6 @@
import Foundation
import MetalPerformanceShaders
@available(iOS 10.0, *)
var convDic: [String : MPSCNNConvolution] = [:]
/// 获取唯一字符串
///
/// - Returns: 唯一字符串
func getUniqueKey() -> String {
return UUID.init().uuidString
}
@available(iOS 11.0, *)
class ConvDataSource<P: PrecisionProtocol>: NSObject, MPSCNNConvolutionDataSource {
......@@ -99,8 +89,7 @@ class ConvDataSource<P: PrecisionProtocol>: NSObject, MPSCNNConvolutionDataSourc
class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: MetalConvParam!
let identifyingKey: String = getUniqueKey()
var mpsConvOp: Any?
required init(device: MTLDevice, param: ConvAddParam<P>, initContext: InitContext) throws {
do {
......@@ -139,7 +128,7 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
func compute(commandBuffer: MTLCommandBuffer, param: ConvAddParam<P>) throws {
if #available(iOS 10.0, *) {
if let conv = convDic[identifyingKey] {
if let conv = mpsConvOp as? MPSCNNConvolution {
let inputImage = MPSImage.init(texture: param.input.metalTexture, featureChannels: param.input.tensorDim[1])
let outputImage = MPSImage.init(texture: param.output.metalTexture, featureChannels: param.output.tensorDim[1])
conv.encode(commandBuffer: commandBuffer, sourceImage: inputImage, destinationImage: outputImage)
......@@ -158,18 +147,10 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
encoder.endEncoding()
}
deinit {
if #available(iOS 10.0, *) {
convDic.removeValue(forKey: identifyingKey)
}
}
func setupWithMPS(device: MTLDevice, param: ConvAddParam<P>) {
let offsetX = (Int(param.dilations[0]) * (param.filter.tensorDim[3] - 1) + 1) / 2 - Int(param.paddings[0])
let offsetY = (Int(param.dilations[1]) * (param.filter.tensorDim[2] - 1) + 1) / 2 - Int(param.paddings[1])
let key = identifyingKey
let isDepthWise = param.filter.tensorDim[1] == 1 && param.filter.tensorDim[0] == param.input.tensorDim[1]
if #available(iOS 11.0, *) {
param.input.useMPS = true
......@@ -192,7 +173,7 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
let conv = MPSCNNConvolution.init(device: device, weights: dataSource)
conv.offset = MPSOffset.init(x: offsetX, y: offsetY, z: 0)
conv.edgeMode = .zero
convDic[key] = conv
mpsConvOp = conv
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册