/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "ActivationFunction.h"
#include "mkldnn.hpp"
#include "paddle/gserver/layers/MKLDNNBase.h"
#include "paddle/math/MKLDNNMatrix.h"
#include "paddle/parameter/Argument.h"

namespace paddle {

/**
 * @brief Base class of MKLDNN Activation.
 * Common activation functions are provided,
 * including mkldnn_relu, mkldnn_elu, mkldnn_tanh and mkldnn_softmax.
 */
class MKLDNNActivation : public ActivationFunction {
protected:
  // input value element count
  size_t cnt_;
  // resetBwd should not be merged into resetFwd,
  // because the grad data may change before backward runs.
  bool needResetBwd_;
  // mkldnn matrix, primitive, stream and pipeline
  MKLDNNMatrixPtr val_;
  MKLDNNMatrixPtr grad_;
  std::shared_ptr<MKLDNNStream> stream_;
  std::shared_ptr<mkldnn::primitive> fwd_;
  std::shared_ptr<mkldnn::primitive> bwd_;
  std::vector<mkldnn::primitive> pipelineFwd_;
  std::vector<mkldnn::primitive> pipelineBwd_;

public:
  MKLDNNActivation() : cnt_(0), needResetBwd_(true) {}
  ~MKLDNNActivation() {}
  static ActivationFunction* create(const std::string& type);
  static std::vector<std::string> getAllRegisteredTypes();
  virtual const std::string& getName() const = 0;
  virtual Error __must_check forward(Argument& act) = 0;
  virtual Error __must_check backward(Argument& act) = 0;
};
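
// A minimal usage sketch (assumptions: "mkldnn_relu" is registered at runtime,
// `output` is a hypothetical Argument holding the layer output, and the error
// checks follow paddle::Error's isOK()/msg() interface):
//
//   ActivationFunction* act = MKLDNNActivation::create("mkldnn_relu");
//   Error err = act->forward(output);   // resets and submits the fwd pipeline
//   CHECK(err.isOK()) << err.msg();
//   err = act->backward(output);        // resets and submits the bwd pipeline
//   CHECK(err.isOK()) << err.msg();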

/**
 * @brief Base class of MKLDNN Eltwise Activation,
 * includes mkldnn_relu, mkldnn_elu and mkldnn_tanh.
 */
class MKLDNNEltwiseActivation : public MKLDNNActivation {
  typedef mkldnn::eltwise_forward eltwise_fwd;
  typedef mkldnn::eltwise_backward eltwise_bwd;

protected:
  // save the forward primitive desc, reused when creating the backward primitive
  std::shared_ptr<eltwise_fwd::primitive_desc> fwdPD_;
  // eltwise_bwd needs the src input value
  MKLDNNMatrixPtr inVal_;
  // used for copying the input data
  std::shared_ptr<mkldnn::reorder> copyInVal_;

public:
  MKLDNNEltwiseActivation() {}

  ~MKLDNNEltwiseActivation() {}

  virtual const std::string& getName() const = 0;

  // In general, the alpha of forward and backward should be equal;
  // but for relu, to avoid negative values, they should have opposite signs.
  virtual float getAlpha() const = 0;
  virtual float getBwdAlpha() const = 0;
  virtual float getBeta() const { return 0.f; }
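
  // A hypothetical subclass illustrating the opposite-sign convention above
  // (the concrete activation classes live in the .cpp; the 0.01f slope here
  // is made up for illustration):
  //
  //   class MKLDNNReluActivation : public MKLDNNEltwiseActivation {
  //     const std::string& getName() const {
  //       static const std::string name("mkldnn_relu");
  //       return name;
  //     }
  //     float getAlpha() const { return -0.01f; }    // forward slope
  //     float getBwdAlpha() const { return 0.01f; }  // opposite sign
  //   };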
  virtual mkldnn::algorithm getAlgo(const std::string& type) const {
    if (type == "mkldnn_relu") {
      return mkldnn::algorithm::eltwise_relu;
    } else if (type == "mkldnn_tanh") {
      return mkldnn::algorithm::eltwise_tanh;
    } else if (type == "mkldnn_elu") {
      return mkldnn::algorithm::eltwise_elu;
    } else {
      LOG(FATAL) << "Unknown eltwise activation type: " << type;
    }
    return (mkldnn::algorithm)0;
  }
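
  // To support another eltwise type, add a branch above that maps the type
  // string to the matching mkldnn::algorithm value (e.g. a hypothetical
  // "mkldnn_sqrt" would map to mkldnn::algorithm::eltwise_sqrt, assuming the
  // linked MKL-DNN build provides that algorithm).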

  /**
   * reshape and reset the forward primitives
   */
  void resetFwd(Argument& act) {
    if (cnt_ == act.value->getElementCnt()) {
      return;
    }
    VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward";
    cnt_ = act.value->getElementCnt();
    stream_.reset(new MKLDNNStream());
    auto eng = CPUEngine::Instance().getEngine();

    // get algo setting
    mkldnn::algorithm algo = getAlgo(this->getName());
    // note: alpha represents the NegativeSlope when used in relu.
    float alpha = getAlpha();
    float beta = getBeta();

    pipelineFwd_.clear();
    val_ = std::dynamic_pointer_cast<MKLDNNMatrix>(act.value);
    if (val_ == nullptr) {
      int bs = act.getBatchSize();
      int ih = act.getFrameHeight() > 0 ? act.getFrameHeight() : 1;
      int iw = act.getFrameWidth() > 0 ? act.getFrameWidth() : 1;
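      // infer channels from the flat element count; a plain (bs x size)
      // output thus maps to nchw with ih = iw = 1 and ic = size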
      int ic = cnt_ / bs / ih / iw;
      CHECK_EQ(cnt_, (size_t)bs * ic * ih * iw);
      val_ = MKLDNNMatrix::create(
          act.value, {bs, ic, ih, iw}, mkldnn::memory::format::nchw, eng);
      CHECK(val_);
    }
    auto fwdDesc = eltwise_fwd::desc(mkldnn::prop_kind::forward_training,
                                     algo,
                                     val_->getMemoryDesc(),
                                     alpha,
                                     beta);
    fwdPD_.reset(new eltwise_fwd::primitive_desc(fwdDesc, eng));
    // run forward in place, but save the input value before submitting
    inVal_ = val_;
    copyInVal_ = nullptr;
    if (act.grad && algo == mkldnn::algorithm::eltwise_tanh) {
      // tanh needs to save the src input for backward
      inVal_ = MKLDNNMatrix::create(nullptr, val_->getPrimitiveDesc());
      copyInVal_ = std::make_shared<mkldnn::reorder>(*val_, *inVal_);
      CHECK(copyInVal_) << "should not be empty";
      pipelineFwd_.push_back(*copyInVal_);
    }
    fwd_.reset(new eltwise_fwd(*fwdPD_, *val_, *val_));
    pipelineFwd_.push_back(*fwd_);
    needResetBwd_ = true;
  }

  /**
   * reset the backward primitives; cannot be merged into resetFwd, as the
   * grad data may change before backward runs.
   */
  void resetBwd(Argument& act) {
    if (!needResetBwd_) {
      return;
    }
    VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward";
    needResetBwd_ = false;
    mkldnn::algorithm algo = getAlgo(this->getName());
    float alpha = getBwdAlpha();
    float beta = getBeta();
    grad_ = MKLDNNMatrix::create(act.grad, val_->getPrimitiveDesc());
    auto eng = CPUEngine::Instance().getEngine();
    auto bwdDesc = eltwise_bwd::desc(
        algo, grad_->getMemoryDesc(), val_->getMemoryDesc(), alpha, beta);
    auto bwdPD = eltwise_bwd::primitive_desc(bwdDesc, eng, *fwdPD_);
    CHECK(inVal_);
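    // in-place backward: grad_ serves as both diff_dst and diff_src,
    // mirroring the in-place forward above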
    bwd_.reset(new eltwise_bwd(bwdPD, *inVal_, *grad_, *grad_));
    pipelineBwd_.clear();
    pipelineBwd_.push_back(*bwd_);
  }

  Error __must_check forward(Argument& act) {
    resetFwd(act);
    stream_->submit(pipelineFwd_);
    return Error();
  }

  Error __must_check backward(Argument& act) {
    resetBwd(act);
    stream_->submit(pipelineBwd_);
    return Error();
  }
};

}  // namespace paddle