ConvKernel.swift 2.2 KB
Newer Older
L
liuruilong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
 
 http://www.apache.org/licenses/LICENSE-2.0
 
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
L
liuruilong 已提交
14 15 16 17

import Foundation


L
liuruilong 已提交
18
/// Constant parameters handed to the convolution Metal function as raw bytes.
///
/// `ConvKernel` copies this value with `setBytes(_:length:index:)` using
/// `MemoryLayout<MetalConvParam>.size`. The field order and types therefore define the
/// byte layout the shader reads: six 2-byte fields, 12 bytes total.
/// NOTE(review): presumably mirrors a struct declared in the `conv_add_1x1` shader source.
/// Keep both in sync, and do not reorder or retype fields here.
public struct MetalConvParam {
    /// Kernel-window offset along x (`ConvKernel` uses half the filter extent minus padding).
    let offsetX: Int16
    /// Kernel-window offset along y (`ConvKernel` uses half the filter extent minus padding).
    let offsetY: Int16
    /// Kernel-window offset along the slice/channel axis (`ConvKernel` sets this to 0).
    let offsetZ: Int16
    /// Convolution stride along x.
    let strideX: UInt16
    /// Convolution stride along y.
    let strideY: UInt16
    /// Texture capacity (`arrayLength * 4`) minus the input's `dim[3]`.
    /// Presumably the number of padding channels in the last slice — TODO confirm.
    let paddedZ: UInt16
}

L
liuruilong 已提交
27
/// Convolution kernel that encodes the `conv_add_1x1` Metal compute function.
///
/// `init` builds the pipeline, uploads the filter buffer, and derives the constant
/// `MetalConvParam` once. `compute` binds the input/output textures, those constants,
/// and the filter buffer, then dispatches over the output texture.
class ConvKernel<P: PrecisionType>: Kernel, Computable {
    /// Shader constants computed in `init`; bound at buffer index 0 by `compute`.
    var metalParam: MetalConvParam!

    /// Creates the kernel and prepares the filter buffer and shader constants.
    ///
    /// - Parameters:
    ///   - device: Metal device passed to the base `Kernel` and used for the filter buffer.
    ///   - param: Convolution parameters (input/output tensors, filter, paddings, stride).
    required init(device: MTLDevice, param: ConvParam<P>) {
        super.init(device: device, inFunctionName: "conv_add_1x1")

        // Window offset = half the filter extent minus the padding on that axis.
        // NOTE(review): pairs filter.dim[2] with paddings[0] for x and filter.dim[1] with
        // paddings[1] for y, exactly as before. Confirm this axis pairing against ConvParam.
        let offsetX = param.filter.dim[2] / 2 - Int(param.paddings[0])
        let offsetY = param.filter.dim[1] / 2 - Int(param.paddings[1])
        // No offset along z. An integer literal is used (was `0.0`) because the value only
        // ever feeds Int16(...); the Double detour served no purpose.
        let offsetZ = 0

        param.filter.initBuffer(device: device, precision: Tensor.BufferPrecision.Float32)

        // Texture capacity (4 per array slice) minus the input's dim[3].
        let paddedZ = param.input.metalTexture.arrayLength * 4 - param.input.dim[3]

        metalParam = MetalConvParam(offsetX: Int16(offsetX),
                                    offsetY: Int16(offsetY),
                                    offsetZ: Int16(offsetZ),
                                    strideX: UInt16(param.stride[0]),
                                    strideY: UInt16(param.stride[1]),
                                    paddedZ: UInt16(paddedZ))
    }

    /// Encodes one convolution pass into `commandBuffer`.
    ///
    /// Bindings:
    /// - texture 0 is the input, texture 1 is the output;
    /// - buffer 0 holds the `MetalConvParam` bytes, buffer 1 holds the filter weights.
    ///
    /// - Throws: `PaddleMobileError.predictError` if a compute encoder cannot be created.
    func compute(commandBuffer: MTLCommandBuffer, param: ConvParam<P>) throws {
        guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
            throw PaddleMobileError.predictError(message: " encode is nil")
        }

        encoder.setTexture(param.input.metalTexture, index: 0)
        encoder.setTexture(param.output.metalTexture, index: 1)

        // Copy into a non-optional local before taking its address. `&metalParam` on the
        // implicitly-unwrapped property points at Optional<MetalConvParam> storage, which only
        // worked because the payload happens to sit at offset 0, and would send zeros if nil.
        var shaderParam: MetalConvParam = metalParam
        encoder.setBytes(&shaderParam, length: MemoryLayout<MetalConvParam>.size, index: 0)
        encoder.setBuffer(param.filter.buffer, offset: 0, index: 1)

        encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
        encoder.endEncoding()
    }
}