提交 bdf037cf 编写于 作者: Y Yanzhan Yang 提交者: GitHub

remove thread unsafe dictionary (#1617)

上级 c3c6cbd3
...@@ -15,16 +15,6 @@ ...@@ -15,16 +15,6 @@
import Foundation import Foundation
import MetalPerformanceShaders import MetalPerformanceShaders
@available(iOS 10.0, *)
var convDic: [String : MPSCNNConvolution] = [:]
/// 获取唯一字符串
///
/// - Returns: 唯一字符串
func getUniqueKey() -> String {
return UUID.init().uuidString
}
@available(iOS 11.0, *) @available(iOS 11.0, *)
class ConvDataSource<P: PrecisionProtocol>: NSObject, MPSCNNConvolutionDataSource { class ConvDataSource<P: PrecisionProtocol>: NSObject, MPSCNNConvolutionDataSource {
...@@ -99,8 +89,7 @@ class ConvDataSource<P: PrecisionProtocol>: NSObject, MPSCNNConvolutionDataSourc ...@@ -99,8 +89,7 @@ class ConvDataSource<P: PrecisionProtocol>: NSObject, MPSCNNConvolutionDataSourc
class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable { class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: MetalConvParam! var metalParam: MetalConvParam!
var mpsConvOp: Any?
let identifyingKey: String = getUniqueKey()
required init(device: MTLDevice, param: ConvAddParam<P>, initContext: InitContext) throws { required init(device: MTLDevice, param: ConvAddParam<P>, initContext: InitContext) throws {
do { do {
...@@ -139,7 +128,7 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable { ...@@ -139,7 +128,7 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
func compute(commandBuffer: MTLCommandBuffer, param: ConvAddParam<P>) throws { func compute(commandBuffer: MTLCommandBuffer, param: ConvAddParam<P>) throws {
if #available(iOS 10.0, *) { if #available(iOS 10.0, *) {
if let conv = convDic[identifyingKey] { if let conv = mpsConvOp as? MPSCNNConvolution {
let inputImage = MPSImage.init(texture: param.input.metalTexture, featureChannels: param.input.tensorDim[1]) let inputImage = MPSImage.init(texture: param.input.metalTexture, featureChannels: param.input.tensorDim[1])
let outputImage = MPSImage.init(texture: param.output.metalTexture, featureChannels: param.output.tensorDim[1]) let outputImage = MPSImage.init(texture: param.output.metalTexture, featureChannels: param.output.tensorDim[1])
conv.encode(commandBuffer: commandBuffer, sourceImage: inputImage, destinationImage: outputImage) conv.encode(commandBuffer: commandBuffer, sourceImage: inputImage, destinationImage: outputImage)
...@@ -158,18 +147,10 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable { ...@@ -158,18 +147,10 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
encoder.endEncoding() encoder.endEncoding()
} }
deinit {
if #available(iOS 10.0, *) {
convDic.removeValue(forKey: identifyingKey)
}
}
func setupWithMPS(device: MTLDevice, param: ConvAddParam<P>) { func setupWithMPS(device: MTLDevice, param: ConvAddParam<P>) {
let offsetX = (Int(param.dilations[0]) * (param.filter.tensorDim[3] - 1) + 1) / 2 - Int(param.paddings[0]) let offsetX = (Int(param.dilations[0]) * (param.filter.tensorDim[3] - 1) + 1) / 2 - Int(param.paddings[0])
let offsetY = (Int(param.dilations[1]) * (param.filter.tensorDim[2] - 1) + 1) / 2 - Int(param.paddings[1]) let offsetY = (Int(param.dilations[1]) * (param.filter.tensorDim[2] - 1) + 1) / 2 - Int(param.paddings[1])
let key = identifyingKey
let isDepthWise = param.filter.tensorDim[1] == 1 && param.filter.tensorDim[0] == param.input.tensorDim[1] let isDepthWise = param.filter.tensorDim[1] == 1 && param.filter.tensorDim[0] == param.input.tensorDim[1]
if #available(iOS 11.0, *) { if #available(iOS 11.0, *) {
param.input.useMPS = true param.input.useMPS = true
...@@ -192,7 +173,7 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable { ...@@ -192,7 +173,7 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
let conv = MPSCNNConvolution.init(device: device, weights: dataSource) let conv = MPSCNNConvolution.init(device: device, weights: dataSource)
conv.offset = MPSOffset.init(x: offsetX, y: offsetY, z: 0) conv.offset = MPSOffset.init(x: offsetX, y: offsetY, z: 0)
conv.edgeMode = .zero conv.edgeMode = .zero
convDic[key] = conv mpsConvOp = conv
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册