BatchNormOp.swift 2.5 KB
Newer Older
L
liuruilong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
///* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// 
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// 
// http://www.apache.org/licenses/LICENSE-2.0
// 
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License. */
L
liuruilong 已提交
14 15 16

import Foundation

L
liuruilong 已提交
17
/// Parameters for the batch-normalization operator, resolved from an op
/// description against a variable scope.
///
/// The `Bias`, `Mean`, `Scale` and `Variance` tensors are looked up in
/// `opDesc.paraInputs`; `epsilon` and `momentum` are read from `opDesc.attrs`.
class BatchNormParam<P: PrecisionType>: OpParam {
  typealias ParamPrecisionType = P

  /// Resolves the input/output textures, parameter tensors and attributes.
  ///
  /// - Parameters:
  ///   - opDesc: Op description supplying inputs, outputs, parameter inputs and attributes.
  ///   - inScope: Scope in which the named variables are looked up.
  /// - Throws: Any error raised by the `OpParam` lookup helpers
  ///   (`inputX`, `outputY`, `getFirstTensor`, `getAttr`), propagated unchanged.
  /// - Note: Traps via `fatalError` when the input texture's `transpose` is not
  ///   `[0, 2, 3, 1]`, the only layout accepted here (reported as NHWC).
  required init(opDesc: OpDesc, inScope: Scope) throws {
    input = try BatchNormParam.inputX(inputs: opDesc.inputs, from: inScope)
    // Any other layout is treated as unrecoverable rather than as a thrown error.
    if input.transpose != [0, 2, 3, 1] {
      fatalError("batch norm only accepts NHWC")
    }
    output = try BatchNormParam.outputY(outputs: opDesc.outputs, from: inScope)

    bias = try BatchNormParam.getFirstTensor(key: "Bias", map: opDesc.paraInputs, from: inScope)
    mean = try BatchNormParam.getFirstTensor(key: "Mean", map: opDesc.paraInputs, from: inScope)
    scale = try BatchNormParam.getFirstTensor(key: "Scale", map: opDesc.paraInputs, from: inScope)
    variance = try BatchNormParam.getFirstTensor(key: "Variance", map: opDesc.paraInputs, from: inScope)

    epsilon = try BatchNormParam.getAttr(key: "epsilon", attrs: opDesc.attrs)
    momentum = try BatchNormParam.getAttr(key: "momentum", attrs: opDesc.attrs)
  }

  /// Input texture; its `transpose` is required to be `[0, 2, 3, 1]`.
  let input: Texture<P>
  /// Output texture; its `dim` is assigned during the op's shape inference.
  var output: Texture<P>

  /// Tensor found under the `Bias` parameter key.
  let bias: Tensor<P>
  /// Tensor found under the `Mean` parameter key.
  let mean: Tensor<P>
  /// Tensor found under the `Scale` parameter key.
  let scale: Tensor<P>
  /// Tensor found under the `Variance` parameter key.
  let variance: Tensor<P>

  /// `epsilon` attribute of the op.
  let epsilon: Float
  /// `momentum` attribute of the op; read here but not used by any code visible in this file.
  let momentum: Float
}

L
liuruilong 已提交
46
/// Batch-normalization operator: gives the output the input's dimensions and
/// delegates the computation to its `BatchNormKernel`.
class BatchNormOp<P: PrecisionType>: Operator<BatchNormKernel<P>, BatchNormParam<P>>, Runable, Creator, InferShaperable {
  typealias OpType = BatchNormOp<P>

  /// Copies the input's `dim` onto the output (the op does not change shape).
  func inferShape() {
    para.output.dim = para.input.dim
  }

  /// Runs the op by handing `buffer` and the op's parameters to the kernel.
  ///
  /// - Parameters:
  ///   - device: Not used by this implementation.
  ///   - buffer: Command buffer passed to `kernel.compute`.
  /// - Throws: Whatever `kernel.compute` throws, propagated unchanged.
  func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
    try kernel.compute(commandBuffer: buffer, param: para)
  }

  /// Debug helper: prints the op type, then the output texture's contents as
  /// read back by `texture2tensor` and formatted by `strideArray()`.
  ///
  /// When the output has no Metal texture yet, a note is printed instead of
  /// crashing on a force-unwrap.
  func delogOutput() {
    print(" \(type) output: ")
    guard let texture = para.output.metalTexture else {
      print(" output has no metal texture")
      return
    }
    let outputArray: [Float32] = texture.device.texture2tensor(texture: texture, dim: para.output.tensorDim.dims, transpose: para.output.transpose)
    print(outputArray.strideArray())
  }
}