PreluOp.swift 2.4 KB
Newer Older
L
update  
liuruilong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
 
 http://www.apache.org/licenses/LICENSE-2.0
 
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
14 15 16 17

import Foundation

/// Parameters for the Prelu operator, resolved from an op description and a scope.
///
/// NOTE(review): the op name suggests PReLU (parametric ReLU); the actual math lives in
/// `PreluKernel`, which is not visible here — confirm against that kernel.
class PreluParam<P: PrecisionType>: OpParam {
  /// Resolves the input/output textures, the `alpha` parameter tensor and the `mode`
  /// attribute from `opDesc`, looking the named variables up in `inScope`.
  ///
  /// - Parameters:
  ///   - opDesc: Operator description supplying `inputs`, `outputs`, `paraInputs` and `attrs`.
  ///   - inScope: Scope in which the referenced variables are looked up.
  /// - Throws: Any error raised by the `PreluParam` lookup helpers; it propagates unchanged.
  required init(opDesc: PMOpDesc, inScope: Scope) throws {
    input = try PreluParam.inputX(inputs: opDesc.inputs, from: inScope)
    output = try PreluParam.outputOut(outputs: opDesc.outputs, from: inScope)
    alpha = try PreluParam.paramInputAlpha(inputs: opDesc.paraInputs, from: inScope)
    mode = try PreluParam.getAttr(key: "mode", attrs: opDesc.attrs)
  }

  /// Value of the op's `"mode"` attribute.
  /// NOTE(review): presumably selects how `alpha` is shared across the input
  /// (e.g. all / channel / element) — confirm against `PreluKernel`.
  let mode: String

  /// The `alpha` tensor, read from the op's parameter inputs (`paraInputs`).
  let alpha: Tensor<P>

  /// Input texture (the op's `X` input, per the `inputX` lookup).
  let input: Texture

  /// Output texture (the op's `Out` output, per the `outputOut` lookup).
  var output: Texture
}

/// Operator that runs `PreluKernel` over the textures described by `PreluParam`.
class PreluOp<P: PrecisionType>: Operator<PreluKernel<P>, PreluParam<P>>, Runable, Creator, InferShaperable {

  typealias OpType = PreluOp<P>

  /// Shape inference is intentionally a no-op for this op.
  ///
  /// NOTE(review): an earlier revision carried a commented-out
  /// `para.output.dim = para.input.dim`; presumably the output texture is already
  /// created with the correct shape elsewhere — confirm before relying on this.
  func inferShape() {
  }

  /// Runs the kernel's `compute` with `buffer` as the command buffer and `para` as its parameters.
  ///
  /// - Parameters:
  ///   - device: Not used by this implementation.
  ///   - buffer: Command buffer handed to `kernel.compute`.
  /// - Throws: Any error thrown by `kernel.compute`; it propagates unchanged.
  func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
    try kernel.compute(commandBuffer: buffer, param: para)
  }

  /// Debug helper: prints the input texture, the `alpha` buffer and the output texture.
  func delogOutput() {
    // Prints a texture's contents as a stride array, sized by its padded 4-D dims
    // (passed to `toTensor` as n, c, h, w).
    func printContents(of texture: Texture) {
      let dims = texture.padToFourDim
      print(texture.metalTexture.toTensor(dim: (n: dims[0], c: dims[1], h: dims[2], w: dims[3])).strideArray())
    }

    print(" \(type) input: ")
    printContents(of: para.input)

    print(" \(type) Alpha: ")
    // The explicit `Float32?` annotation selects the element type `logDesc` logs as.
    let _: Float32? = para.alpha.buffer.logDesc(header: " alpha: ", stridable: false)

    print(" \(type) output: ")
    printContents(of: para.output)
  }
}