/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <memory>
#include <string>
#include <vector>
#include "ActivationFunction.h"
#include "mkldnn.hpp"
#include "paddle/legacy/gserver/layers/MKLDNNBase.h"
#include "paddle/math/MKLDNNMatrix.h"
#include "paddle/parameter/Argument.h"

namespace paddle {

/**
 * @brief Base class of MKLDNN Activation.
 * Common activation functions are provided,
 * including mkldnn_relu, mkldnn_elu, mkldnn_tanh and mkldnn_softmax.
 */
class MKLDNNActivation : public ActivationFunction {
W
Wu Yi 已提交
30
 protected:
T
tensor-tang 已提交
31 32
  // input value element count
  size_t cnt_;
33 34 35
  // should not merge the resetBwd into resetFwd,
  // because the grad data would be changing before backward.
  bool needResetBwd_;
T
tensor-tang 已提交
36 37 38
  // mkldnn matrix, primitive, stream and pipeline
  MKLDNNMatrixPtr val_;
  MKLDNNMatrixPtr grad_;
T
tensor-tang 已提交
39
  std::shared_ptr<mkldnn::engine> engine_;
T
tensor-tang 已提交
40 41 42 43 44 45
  std::shared_ptr<MKLDNNStream> stream_;
  std::shared_ptr<mkldnn::primitive> fwd_;
  std::shared_ptr<mkldnn::primitive> bwd_;
  std::vector<mkldnn::primitive> pipelineFwd_;
  std::vector<mkldnn::primitive> pipelineBwd_;

W
Wu Yi 已提交
46
 public:
47
  MKLDNNActivation() : cnt_(0), needResetBwd_(true) {}
T
tensor-tang 已提交
48 49 50 51
  ~MKLDNNActivation() {}
  static ActivationFunction* create(const std::string& type);
  static std::vector<std::string> getAllRegisteredTypes();
  virtual const std::string& getName() const = 0;
T
tensor-tang 已提交
52 53 54
  /**
   * reset the forward primitives
   */
T
tensor-tang 已提交
55
  virtual void resetFwd(Argument& act);
T
tensor-tang 已提交
56 57 58 59 60 61
  /**
   * reset the backward primitives,
   * can not merge this functions into resetFwd as the grad data
   * would be changing before backward.
   */
  virtual void resetBwd(Argument& act) {}
T
tensor-tang 已提交
62 63
  virtual Error __must_check forward(Argument& act);
  virtual Error __must_check backward(Argument& act);
T
tensor-tang 已提交
64 65 66 67 68 69 70 71 72
};

/**
 * @brief Base class of MKLDNN Eltwise Activation,
 * includes mkldnn_relu, mkldnn_elu and mkldnn_tanh.
 */
class MKLDNNEltwiseActivation : public MKLDNNActivation {
  typedef mkldnn::eltwise_forward eltwise_fwd;
  typedef mkldnn::eltwise_backward eltwise_bwd;
T
tensor-tang 已提交
73
  typedef mkldnn::algorithm algorithm;
T
tensor-tang 已提交
74

W
Wu Yi 已提交
75
 protected:
76 77 78 79 80 81 82
  // save the forward primitive desc, which can be used backward
  std::shared_ptr<eltwise_fwd::primitive_desc> fwdPD_;
  // eltwise_bwd need src input value
  MKLDNNMatrixPtr inVal_;
  // use for copy data
  std::shared_ptr<mkldnn::reorder> copyInVal_;

W
Wu Yi 已提交
83
 public:
T
tensor-tang 已提交
84 85 86
  MKLDNNEltwiseActivation() {}
  ~MKLDNNEltwiseActivation() {}
  virtual const std::string& getName() const = 0;
87 88 89

  // in common, the alpha of forward and backward should be equal.
  // but for relu, to avoid negative value, they should be opposite
T
tensor-tang 已提交
90
  virtual float getAlpha() const = 0;
91
  virtual float getBwdAlpha() const = 0;
T
tensor-tang 已提交
92
  virtual float getBeta() const { return 0.f; }
T
tensor-tang 已提交
93 94 95
  virtual algorithm getAlgo(std::string type) const;
  void resetFwd(Argument& act) override;
  void resetBwd(Argument& act) override;
T
tensor-tang 已提交
96
};

/**
 * @brief Base class of MKLDNN softmax Activation.
 * Only the forward pass uses MKLDNN; the backward pass uses the CPU
 * implementation.
 */
class MKLDNNSoftmaxActivation : public MKLDNNActivation {
  typedef mkldnn::softmax_forward softmax_fwd;

W
Wu Yi 已提交
105
 private:
T
tensor-tang 已提交
106 107 108 109
  // for backward
  MatrixPtr sftMaxSum_;
  MatrixPtr sftMaxDot_;

W
Wu Yi 已提交
110
 public:
T
tensor-tang 已提交
111 112 113
  MKLDNNSoftmaxActivation() {}
  ~MKLDNNSoftmaxActivation() {}
  virtual const std::string& getName() const = 0;
T
tensor-tang 已提交
114 115 116
  void resetFwd(Argument& act) override;
  Error __must_check forward(Argument& act) override;
  Error __must_check backward(Argument& act) override;
T
tensor-tang 已提交
117 118 119
};

}  // namespace paddle