BatchNormOp.swift 2.4 KB
Newer Older
L
liuruilong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// 
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// 
// http://www.apache.org/licenses/LICENSE-2.0
// 
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License. */
L
liuruilong 已提交
14 15 16

import Foundation

L
liuruilong 已提交
17
/// Parameter bundle for the batch-normalization operator.
///
/// Every stored value is resolved once, in the initializer, from an operator
/// description and a scope; nothing is recomputed afterwards.
class BatchNormParam<P: PrecisionType>: OpParam {
    typealias ParamPrecisionType = P

    /// Resolves all inputs, outputs and attributes of the operator.
    ///
    /// - Parameters:
    ///   - opDesc: Operator description supplying `inputs`, `outputs`,
    ///     `paraInputs` and `attrs`.
    ///   - inScope: Scope handed to each lookup helper.
    /// - Throws: Whatever error a lookup helper or `getAttr` throws,
    ///   propagated unchanged. The first failing lookup aborts initialization.
    required init(opDesc: OpDesc, inScope: Scope) throws {
        // `try` already propagates errors to the caller, so no `do`/`catch`
        // wrapper is needed just to rethrow them.
        input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: inScope)
        output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: inScope)

        // The four per-operator tensors are looked up through `paraInputs`.
        inputBias = try BatchNormParam.inputBiase(inputs: opDesc.paraInputs, from: inScope)
        inputMean = try BatchNormParam.inputMean(inputs: opDesc.paraInputs, from: inScope)
        inputScale = try BatchNormParam.inputScale(inputs: opDesc.paraInputs, from: inScope)
        inputVariance = try BatchNormParam.inputVariance(inputs: opDesc.paraInputs, from: inScope)

        // Scalar attributes, read by key from the operator's attribute table.
        epsilon = try BatchNormParam.getAttr(key: "epsilon", attrs: opDesc.attrs)
        momentum = try BatchNormParam.getAttr(key: "momentum", attrs: opDesc.attrs)
        is_test = try BatchNormParam.getAttr(key: "is_test", attrs: opDesc.attrs)
    }

    /// Input texture, resolved via `inputX`.
    let input: Texture<P>
    /// Output texture, resolved via `outputY`; mutable so its `dim` can be assigned later.
    var output: Texture<P>

    /// Bias tensor, resolved via `inputBiase`.
    let inputBias: Tensor<ParamPrecisionType>
    /// Mean tensor, resolved via `inputMean`.
    let inputMean: Tensor<ParamPrecisionType>
    /// Scale tensor, resolved via `inputScale`.
    let inputScale: Tensor<ParamPrecisionType>
    /// Variance tensor, resolved via `inputVariance`.
    let inputVariance: Tensor<ParamPrecisionType>

    /// Value of the `"epsilon"` attribute.
    let epsilon: Float
    /// Value of the `"momentum"` attribute.
    let momentum: Float
    /// Value of the `"is_test"` attribute.
    let is_test: Bool
}

L
liuruilong 已提交
45
/// Batch-normalization operator that pairs `BatchNormKernel<P>` with
/// `BatchNormParam<P>`.
class BatchNormOp<P: PrecisionType>: Operator<BatchNormKernel<P>, BatchNormParam<P>>, Runable, Creator, InferShaperable {
    typealias OpType = BatchNormOp<P>

    /// Assigns the input's `dim` to the output.
    func inferShape() {
        para.output.dim = para.input.dim
    }

    /// Runs the kernel against this operator's parameters.
    ///
    /// - Parameters:
    ///   - device: Metal device supplied by the caller; not used in this method.
    ///   - buffer: Command buffer forwarded to `kernel.compute`.
    /// - Throws: Any error thrown by `kernel.compute`, propagated unchanged.
    func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
        // `try` forwards the kernel's error directly; a catch-and-rethrow
        // wrapper would add nothing.
        try kernel.compute(commandBuffer: buffer, param: para)
    }
}
L
liuruilong 已提交
58 59 60 61 62