/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "ActivationFunction.h"
#include "mkldnn.hpp"
#include "paddle/gserver/layers/MKLDNNBase.h"
#include "paddle/math/MKLDNNMatrix.h"
#include "paddle/parameter/Argument.h"

namespace paddle {

/**
 * @brief Base class of MKLDNN Activation.
 * Common activation functions are provided,
 * including mkldnn_relu, mkldnn_elu, mkldnn_tanh and mkldnn_softmax.
 */
class MKLDNNActivation : public ActivationFunction {
protected:
  // input value element count
  size_t cnt_;
  // resetBwd should not be merged into resetFwd,
  // because the grad data may change before backward runs.
  bool needResetBwd_;
  // mkldnn matrix, primitive, stream and pipeline
  MKLDNNMatrixPtr val_;
  MKLDNNMatrixPtr grad_;
  std::shared_ptr<MKLDNNStream> stream_;
  std::shared_ptr<mkldnn::primitive> fwd_;
  std::shared_ptr<mkldnn::primitive> bwd_;
  std::vector<mkldnn::primitive> pipelineFwd_;
  std::vector<mkldnn::primitive> pipelineBwd_;

public:
  MKLDNNActivation() : cnt_(0), needResetBwd_(true) {}
  ~MKLDNNActivation() {}
  static ActivationFunction* create(const std::string& type);
  static std::vector<std::string> getAllRegisteredTypes();
  virtual const std::string& getName() const = 0;
  virtual Error __must_check forward(Argument& act) = 0;
  virtual Error __must_check backward(Argument& act) = 0;
};
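
// A minimal usage sketch (illustrative only; assumes "mkldnn_relu" is a
// registered type and `act` is a prepared paddle::Argument):
//   ActivationFunction* relu = MKLDNNActivation::create("mkldnn_relu");
//   Error err = relu->forward(act);
//   err = relu->backward(act);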

/**
 * @brief Base class of MKLDNN Eltwise Activation,
 * includes mkldnn_relu, mkldnn_elu and mkldnn_tanh.
 */
class MKLDNNEltwiseActivation : public MKLDNNActivation {
  typedef mkldnn::eltwise_forward eltwise_fwd;
  typedef mkldnn::eltwise_backward eltwise_bwd;

protected:
  // save the forward primitive desc, which is reused when building backward
  std::shared_ptr<eltwise_fwd::primitive_desc> fwdPD_;
  // eltwise_bwd needs the original src input value
  MKLDNNMatrixPtr inVal_;
  // reorder primitive used to copy the input data
  std::shared_ptr<mkldnn::reorder> copyInVal_;

public:
  MKLDNNEltwiseActivation() {}

  ~MKLDNNEltwiseActivation() {}

  virtual const std::string& getName() const = 0;

  // In general, the alpha of forward and backward should be equal;
  // but for relu, to avoid negative values, they should be opposite.
  virtual float getAlpha() const = 0;
  virtual float getBwdAlpha() const = 0;
  virtual float getBeta() const { return 0.f; }
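  // For reference, the standard MKL-DNN eltwise definitions (alpha is the
  // negative slope for relu and the scale for elu; beta is unused here):
  //   eltwise_relu: f(x) = x > 0 ? x : alpha * x
  //   eltwise_tanh: f(x) = tanh(x)
  //   eltwise_elu:  f(x) = x > 0 ? x : alpha * (exp(x) - 1)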
  virtual mkldnn::algorithm getAlgo(const std::string& type) const {
    if (type == "mkldnn_relu") {
      return mkldnn::algorithm::eltwise_relu;
    } else if (type == "mkldnn_tanh") {
      return mkldnn::algorithm::eltwise_tanh;
    } else if (type == "mkldnn_elu") {
      return mkldnn::algorithm::eltwise_elu;
    } else {
      LOG(FATAL) << "Unknown eltwise activation type: " << type;
    }
    return (mkldnn::algorithm)0;
  }

  /**
   * reshape and reset the forward primitives
   */
  void resetFwd(Argument& act) {
    if (cnt_ == act.value->getElementCnt()) {
      return;
    }
    cnt_ = act.value->getElementCnt();
    stream_.reset(new MKLDNNStream());
    auto eng = CPUEngine::Instance().getEngine();

    // get algo setting
    mkldnn::algorithm algo = getAlgo(this->getName());
    // note: alpha represents the NegativeSlope when used in relu.
    float alpha = getAlpha();
    float beta = getBeta();

    /// forward
    pipelineFwd_.clear();
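    // reuse act.value directly when it is already an MKLDNNMatrix; otherwise
    // wrap it in nchw layout, inferring the channel count from the total
    // element count divided by batch size, height and width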
    val_ = std::dynamic_pointer_cast<MKLDNNMatrix>(act.value);
    if (val_ == nullptr) {
      int bs = act.getBatchSize();
      int ih = act.getFrameHeight() > 0 ? act.getFrameHeight() : 1;
      int iw = act.getFrameWidth() > 0 ? act.getFrameWidth() : 1;
      int ic = cnt_ / bs / ih / iw;
      CHECK_EQ(cnt_, (size_t)bs * ic * ih * iw);
      val_ = MKLDNNMatrix::create(
          act.value, {bs, ic, ih, iw}, mkldnn::memory::format::nchw, eng);
      CHECK(val_);
    }
    auto fwdDesc = eltwise_fwd::desc(mkldnn::prop_kind::forward_training,
                                     algo,
                                     val_->getMemoryDesc(),
                                     alpha,
                                     beta);
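    // forward_training (rather than forward_inference) is used since fwdPD_
    // is kept as the hint for creating the backward primitive desc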
    fwdPD_.reset(new eltwise_fwd::primitive_desc(fwdDesc, eng));
    // use inplace for forward but save input value before submit
    inVal_ = val_;
    if (act.grad) {
      // only copy when backward is needed
      inVal_ = MKLDNNMatrix::create(nullptr, val_->getPrimitiveDesc());
      copyInVal_ = std::make_shared<mkldnn::reorder>(*val_, *inVal_);
      CHECK(copyInVal_) << "should not be empty";
      pipelineFwd_.push_back(*copyInVal_);
    }
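    // in-place forward: val_ serves as both src and dst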
    fwd_.reset(new eltwise_fwd(*fwdPD_, *val_, *val_));
    pipelineFwd_.push_back(*fwd_);
    needResetBwd_ = true;
  }

  /**
   * reset the backward primitives; cannot be merged into resetFwd,
   * as the grad data may change before backward runs.
   */
  void resetBwd(Argument& act) {
    if (!needResetBwd_) {
      return;
    }
    needResetBwd_ = false;
    mkldnn::algorithm algo = getAlgo(this->getName());
    float alpha = getBwdAlpha();
    float beta = getBeta();
    grad_ = MKLDNNMatrix::create(act.grad, val_->getPrimitiveDesc());
    auto eng = CPUEngine::Instance().getEngine();
    auto bwdDesc = eltwise_bwd::desc(
        algo, grad_->getMemoryDesc(), val_->getMemoryDesc(), alpha, beta);
    auto bwdPD = eltwise_bwd::primitive_desc(bwdDesc, eng, *fwdPD_);
    CHECK(inVal_);
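    // eltwise_bwd takes (src, diff_dst, diff_src): grad_ is used as both
    // diff_dst and diff_src, so the backward also runs in place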
    bwd_.reset(new eltwise_bwd(bwdPD, *inVal_, *grad_, *grad_));
    pipelineBwd_.clear();
    pipelineBwd_.push_back(*bwd_);
  }

  Error __must_check forward(Argument& act) {
    resetFwd(act);
    stream_->submit(pipelineFwd_);
    return Error();
  }

  Error __must_check backward(Argument& act) {
    resetBwd(act);
    stream_->submit(pipelineBwd_);
    return Error();
  }
};
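
/**
 * Illustrative sketch only (not part of this header): a concrete eltwise
 * activation just supplies its registered type name and alpha values and
 * inherits all the primitive handling. The names below are hypothetical;
 * the real definitions live in the corresponding .cpp file.
 *
 *   class MKLDNNReluActivation : public MKLDNNEltwiseActivation {
 *   private:
 *     static const std::string name;  // e.g. "mkldnn_relu"
 *   public:
 *     const std::string& getName() const { return name; }
 *     // negative slope of relu; 0.f gives the standard relu
 *     float getAlpha() const { return 0.f; }
 *     float getBwdAlpha() const { return 0.f; }
 *   };
 */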

}  // namespace paddle