/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <utility>
#include <vector>
#include "Matrix.h"
#include "mkldnn.hpp"
#include "paddle/parameter/Parameter.h"

namespace paddle {

class MKLDNNMatrix;
/// Shared-ownership handle to an MKLDNNMatrix (alias of std::shared_ptr).
using MKLDNNMatrixPtr = std::shared_ptr<MKLDNNMatrix>;

/**
 * @brief MKLDNN Matrix.
 *
 */
T
tensor-tang 已提交
31
class MKLDNNMatrix : public CpuMatrix, public mkldnn::memory {
T
tensor-tang 已提交
32
public:
33 34 35 36
  MKLDNNMatrix(CpuMatrixPtr m, mkldnn::memory::primitive_desc pd)
      : CpuMatrix(m->getData(), m->getHeight(), m->getWidth(), false),
        mkldnn::memory(pd, m->getData()),
        m_(m) {}
T
tensor-tang 已提交
37

T
tensor-tang 已提交
38 39
  ~MKLDNNMatrix() {}

40 41 42 43 44 45 46 47
  /**
   * Create MKLDNNMatrix from a MatrixPtr and memory primitive_desc
   */
  static MKLDNNMatrixPtr create(MatrixPtr m, mkldnn::memory::primitive_desc pd);

  /**
   * Create MKLDNNMatrix from a MatrixPtr and memory details info
   */
T
tensor-tang 已提交
48
  static MKLDNNMatrixPtr create(
49
      MatrixPtr m,
T
tensor-tang 已提交
50 51 52 53 54
      mkldnn::memory::dims dims,
      mkldnn::memory::format fmt,
      mkldnn::engine& eg,
      mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32);

55
public:
T
tensor-tang 已提交
56 57
  /**
   * Reorder this MKLDNNMatrix from other format.
T
refine  
tensor-tang 已提交
58 59 60
   * Support inplace reorder.
   * @note: this function would only reorder the data layout.
   *        will NOT change this original dim or format info
T
tensor-tang 已提交
61 62 63 64 65 66 67
   */
  void reorderDataFrom(const MKLDNNMatrixPtr& m,
                       memory::format srcFmt,
                       memory::dims targetDim);

  /**
   * Reorder this MKLDNNMatrix to other format.
T
refine  
tensor-tang 已提交
68 69 70
   * Support inplace reorder.
   * @note: this function would only reorder the data layout.
   *        will NOT change the dst dim or format info
T
tensor-tang 已提交
71 72 73 74 75
   */
  void reorderDataTo(const MKLDNNMatrixPtr& m,
                     memory::format dstFmt,
                     memory::dims targetDim);

76 77 78 79 80 81 82
  /**
   * Dimensionality reduction.
   * Change format "nchw --> nc" or "oihw --> oi" if the h and w are both 1
   */
  void downSpatial();

  /**
83
   * set the memory data handle.
84 85 86
   * Caution: This will not check the buffer size of the data,
   *          it should be coverd by user.
   */
87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
  void setData(real* data) {
    set_data_handle(data);
    CpuMatrix::setData(data);
    m_.reset();
  }

  /**
   * override Matrix::getData
   * check data before return
   */
  real* getData() override {
    CHECK_EQ((void*)data_, get_data_handle());
    return data_;
  }

  const real* getData() const override {
    CHECK_EQ((void*)data_, get_data_handle());
    return data_;
  }
106

T
tensor-tang 已提交
107
  /**
T
tensor-tang 已提交
108
   * Get primitive descriptor.
T
tensor-tang 已提交
109
   */
T
refine  
tensor-tang 已提交
110 111 112
  mkldnn::memory::primitive_desc getPrimitiveDesc() {
    return this->get_primitive_desc();
  }
T
tensor-tang 已提交
113

T
tensor-tang 已提交
114
  /**
T
tensor-tang 已提交
115
   * Get memory descriptor.
T
tensor-tang 已提交
116
   */
T
refine  
tensor-tang 已提交
117
  mkldnn::memory::desc getMemoryDesc() { return getPrimitiveDesc().desc(); }
T
tensor-tang 已提交
118 119

  /**
120
   * Get dimensions.
T
tensor-tang 已提交
121
   */
T
tensor-tang 已提交
122
  mkldnn::memory::dims getDims() {
T
refine  
tensor-tang 已提交
123
    mkldnn::memory::desc md = getMemoryDesc();
124 125
    const int* src = md.data.dims;
    int ndims = md.data.ndims;
T
tensor-tang 已提交
126 127 128 129 130 131 132
    mkldnn::memory::dims dst;
    dst.resize(ndims);
    for (int i = 0; i < ndims; ++i) {
      dst[i] = src[i];
    }
    return dst;
  }
T
tensor-tang 已提交
133

T
tensor-tang 已提交
134 135 136 137
  /**
   * Get format.
   */
  mkldnn::memory::format getFormat() {
T
refine  
tensor-tang 已提交
138
    return (mkldnn::memory::format)(getMemoryDesc().data.format);
T
tensor-tang 已提交
139 140 141
  }

  /**
142
   * Get memory data type.
T
tensor-tang 已提交
143
   */
144
  mkldnn::memory::data_type getDtype() {
T
refine  
tensor-tang 已提交
145
    return (mkldnn::memory::data_type)(getMemoryDesc().data.data_type);
146 147 148 149 150
  }

  /**
   * Get engine.
   */
T
refine  
tensor-tang 已提交
151
  mkldnn::engine getEngine() { return getPrimitiveDesc().get_engine(); }
T
tensor-tang 已提交
152 153 154

protected:
  /**
T
refine  
tensor-tang 已提交
155 156
   * Do reorder once.
   * Can support inplace.
T
tensor-tang 已提交
157 158 159 160 161 162
   */
  void reorderOnce(void* srcData,
                   void* dstData,
                   memory::format srcFmt,
                   memory::format dstFmt,
                   memory::dims dm);
163 164 165 166

private:
  // save the CpuMatrixPtr in case the buffer released outside
  CpuMatrixPtr m_;
T
tensor-tang 已提交
167 168 169
};

}  // namespace paddle