MKLDNNActivation.h 3.8 KB
Newer Older
T
tensor-tang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "ActivationFunction.h"
#include "mkldnn.hpp"
#include "paddle/gserver/layers/MKLDNNBase.h"
#include "paddle/math/MKLDNNMatrix.h"
#include "paddle/parameter/Argument.h"

namespace paddle {

/**
 * @brief Base class of MKLDNN Activation.
 * Common activation function are provieded,
 * including mkldnn_relu, mkldnn_elu, mkldnn_tanh, mkldnn_softmax
 */
class MKLDNNActivation : public ActivationFunction {
protected:
  // input value element count
  size_t cnt_;
33 34 35
  // should not merge the resetBwd into resetFwd,
  // because the grad data would be changing before backward.
  bool needResetBwd_;
T
tensor-tang 已提交
36 37 38
  // mkldnn matrix, primitive, stream and pipeline
  MKLDNNMatrixPtr val_;
  MKLDNNMatrixPtr grad_;
T
tensor-tang 已提交
39
  std::shared_ptr<mkldnn::engine> engine_;
T
tensor-tang 已提交
40 41 42 43 44 45 46
  std::shared_ptr<MKLDNNStream> stream_;
  std::shared_ptr<mkldnn::primitive> fwd_;
  std::shared_ptr<mkldnn::primitive> bwd_;
  std::vector<mkldnn::primitive> pipelineFwd_;
  std::vector<mkldnn::primitive> pipelineBwd_;

public:
47
  MKLDNNActivation() : cnt_(0), needResetBwd_(true) {}
T
tensor-tang 已提交
48 49 50 51
  ~MKLDNNActivation() {}
  static ActivationFunction* create(const std::string& type);
  static std::vector<std::string> getAllRegisteredTypes();
  virtual const std::string& getName() const = 0;
T
tensor-tang 已提交
52 53 54
  /**
   * reset the forward primitives
   */
T
tensor-tang 已提交
55
  virtual void resetFwd(Argument& act);
T
tensor-tang 已提交
56 57 58 59 60 61
  /**
   * reset the backward primitives,
   * can not merge this functions into resetFwd as the grad data
   * would be changing before backward.
   */
  virtual void resetBwd(Argument& act) {}
T
tensor-tang 已提交
62 63
  virtual Error __must_check forward(Argument& act);
  virtual Error __must_check backward(Argument& act);
T
tensor-tang 已提交
64 65 66 67 68 69 70 71 72
};

/**
 * @brief Base class of MKLDNN Eltwise Activation,
 * includes mkldnn_relu, mkldnn_elu and mkldnn_tanh.
 */
class MKLDNNEltwiseActivation : public MKLDNNActivation {
  typedef mkldnn::eltwise_forward eltwise_fwd;
  typedef mkldnn::eltwise_backward eltwise_bwd;
T
tensor-tang 已提交
73
  typedef mkldnn::algorithm algorithm;
T
tensor-tang 已提交
74

75 76 77 78 79 80 81 82
protected:
  // save the forward primitive desc, which can be used backward
  std::shared_ptr<eltwise_fwd::primitive_desc> fwdPD_;
  // eltwise_bwd need src input value
  MKLDNNMatrixPtr inVal_;
  // use for copy data
  std::shared_ptr<mkldnn::reorder> copyInVal_;

T
tensor-tang 已提交
83 84 85 86
public:
  MKLDNNEltwiseActivation() {}
  ~MKLDNNEltwiseActivation() {}
  virtual const std::string& getName() const = 0;
87 88 89

  // in common, the alpha of forward and backward should be equal.
  // but for relu, to avoid negative value, they should be opposite
T
tensor-tang 已提交
90
  virtual float getAlpha() const = 0;
91
  virtual float getBwdAlpha() const = 0;
T
tensor-tang 已提交
92
  virtual float getBeta() const { return 0.f; }
T
tensor-tang 已提交
93 94 95
  virtual algorithm getAlgo(std::string type) const;
  void resetFwd(Argument& act) override;
  void resetBwd(Argument& act) override;
T
tensor-tang 已提交
96
};
T
tensor-tang 已提交
97

T
tensor-tang 已提交
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
/**
 * @brief Base class of MKLDNN softmax Activation,
 * only have mkldnn forward, use cpu implement for backward.
 */
class MKLDNNSoftmaxActivation : public MKLDNNActivation {
  typedef mkldnn::softmax_forward softmax_fwd;

private:
  // for backward
  MatrixPtr sftMaxSum_;
  MatrixPtr sftMaxDot_;

public:
  MKLDNNSoftmaxActivation() {}
  ~MKLDNNSoftmaxActivation() {}
  virtual const std::string& getName() const = 0;
T
tensor-tang 已提交
114 115 116
  void resetFwd(Argument& act) override;
  Error __must_check forward(Argument& act) override;
  Error __must_check backward(Argument& act) override;
T
tensor-tang 已提交
117 118 119
};

}  // namespace paddle