未验证 提交 d9b93ad9 编写于 作者: Y Yanzhan Yang 提交者: GitHub

1. reshape output transpose should be 0, 1, 2, 3. 2.mps conv should be...

1. reshape output transpose should be 0, 1, 2, 3. 2.mps conv should be disabled when groups is larger than 1. 3. slice should re-tranpose its input if needed. (#1646)
上级 fea601e9
......@@ -114,6 +114,10 @@ class ConvAddKernel<P: PrecisionProtocol>: Kernel, Computable {
if type(of: self).isWinoGrad(functionName: functionName) {
shouldUseMPS = false
}
let isDepthWise = param.filter.tensorDim[1] == 1 && param.filter.tensorDim[0] == param.input.tensorDim[1]
if !isDepthWise && param.groups > 1 {
shouldUseMPS = false
}
if shouldUseMPS {
super.init(device: device, inFunctionName: nil, initContext: initContext)
setupWithMPS(device: device, param: param)
......
......@@ -34,7 +34,7 @@ class ReshapeKernel<P: PrecisionProtocol>: Kernel, Computable{
required init(device: MTLDevice, param: ReshapeParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
......
......@@ -29,11 +29,24 @@ public struct SliceMetalParam {
class SliceKernel<P: PrecisionProtocol>: Kernel, Computable {
var metalParam: SliceMetalParam
var device: MTLDevice?
var initContext: InitContext?
func compute(commandBuffer: MTLCommandBuffer, param: SliceParam<P>) throws {
let expectedTranspose = [0, 2, 3, 1]
var input = param.input
if param.input.transpose != expectedTranspose {
if let device = device, let initContext = initContext, let transposedInput = encodeTransposeInput(input: param.input, toTranspose: expectedTranspose, commandBuffer: commandBuffer, device: device, initContext: initContext) {
input = transposedInput
} else {
print("input transpose failed in slice kernel")
}
}
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setBytes(&metalParam, length: MemoryLayout<SliceMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
......@@ -72,7 +85,9 @@ class SliceKernel<P: PrecisionProtocol>: Kernel, Computable {
} else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "slice_half", initContext: initContext)
} else {
fatalError()
fatalError("unknown computePrecision")
}
self.device = device
self.initContext = initContext
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册