/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include <memory>
#include <numeric>
#include <vector>

#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using framework::DataLayout;

template <typename T>
class TransposeMKLDNNHandler {
 public:
  TransposeMKLDNNHandler(std::vector<int64_t>& dims,  // NOLINT
                         std::vector<int>& axis,      // NOLINT
32
                         dnnl::engine engine)
33 34 35 36 37
      : dims_(dims),
        axis_(axis),
        logical_axis_(dims.size(), 0),
        engine_(engine) {}

38 39
  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const MKLDNNMemoryFormat& fmt,
                                                 void* ptr) {
40 41 42 43 44 45 46 47 48 49
    // Make memory descriptor using input format, unless it
    // cannot be trusted (nchw) then make up memory fmt manually
    for (size_t i = 0; i < this->logical_axis_.size(); ++i) {
      this->logical_axis_[i] = i;
    }

    auto src_md = fmt != MKLDNNMemoryFormat::nchw
                      ? platform::MKLDNNMemDesc(
                            dims_, platform::MKLDNNGetDataType<T>(), fmt)
                      : Axis2MemoryDesc(dims_, logical_axis_);
50
    return std::make_shared<dnnl::memory>(src_md, engine_, ptr);
51 52
  }

53 54
  std::shared_ptr<dnnl::memory> AcquireDstMemory(framework::Tensor* output,
                                                 platform::Place place) {
55 56
    auto dst_md = Axis2MemoryDesc(dims_, axis_);
    auto dst_data = output->mutable_data<T>(place, dst_md.get_size());
57
    return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
58 59
  }

60 61 62 63
  std::shared_ptr<dnnl::reorder> AcquireTranspose(
      std::shared_ptr<dnnl::memory> dst_memory_p,
      std::shared_ptr<dnnl::memory> src_memory_p) {
    return std::make_shared<dnnl::reorder>(*(src_memory_p), *(dst_memory_p));
64 65 66
  }

 protected:
67 68
  dnnl::memory::desc Axis2MemoryDesc(std::vector<int64_t>& nchw_tz,  // NOLINT
                                     std::vector<int>& axis          // NOLINT
69
  ) {
70 71 72 73 74 75 76 77
    size_t ndims = axis.size();

    std::vector<int64_t> strides(ndims);
    unsigned int total_stride = 1;
    for (int i = ndims - 1; i >= 0; --i) {
      strides[axis[i]] = total_stride;
      total_stride *= nchw_tz[axis[i]];
    }
78 79
    dnnl::memory::desc mem_d(
        nchw_tz, platform::MKLDNNGetDataType<T>(), strides);
80 81 82 83 84 85 86 87

    return mem_d;
  }

 private:
  std::vector<int64_t> dims_;
  std::vector<int> axis_;
  std::vector<int> logical_axis_;
88
  dnnl::engine engine_;
89 90
};

// Forward oneDNN kernel for transpose/transpose2. Permutes input "X" into
// "Out" according to the "axis" attribute, using a oneDNN reorder.
template <typename T>
class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
                      true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL Transpose must use CPUPlace"));
    auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();
    std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    const T* input_data = input->data<T>();

    // A single-axis transpose is the identity: copy and keep the source format.
    if (axis.size() == 1) {
      framework::TensorCopy(*input, input->place(), output);
      output->set_format(input->format());
      return;
    }

    auto src_dims = phi::vectorize<int64_t>(input->dims());
    TransposeMKLDNNHandler<T> handler(src_dims, axis, mkldnn_engine);

    auto src_memory_p = handler.AcquireSrcMemory(
        input->format(), platform::to_void_cast<T>(input_data));
    auto dst_memory_p = handler.AcquireDstMemory(output, ctx.GetPlace());
    auto reorder_p = handler.AcquireTranspose(dst_memory_p, src_memory_p);

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    reorder_p->execute(astream, *src_memory_p, *dst_memory_p);
    astream.wait();

    // The reorder wrote the data in plain transposed order, so the output is
    // tagged as an ordinary NCHW tensor without a oneDNN format.
    output->set_layout(DataLayout::kNCHW);
    output->set_format(MKLDNNMemoryFormat::undef);
  }
};

// Backward oneDNN kernel for transpose. Transposes Out@GRAD back into X@GRAD
// using the inverse of the forward "axis" permutation.
template <typename T>
class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
                      true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL TransposeGrad must use CPUPlace"));
    auto* out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    // X@GRAD was not requested: nothing to compute.
    if (!x_grad) return;
    auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();
    std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");

    // A single-axis transpose is the identity: copy and keep the source format.
    if (axis.size() == 1) {
      framework::TensorCopy(*out_grad, out_grad->place(), x_grad);
      x_grad->set_format(out_grad->format());
      return;
    }

    // Inverse permutation: reversed_axis[axis[i]] == i.
    std::vector<int> reversed_axis(axis);
    for (size_t i = 0; i < axis.size(); ++i) {
      reversed_axis[axis[i]] = static_cast<int>(i);
    }

    const T* out_grad_data = out_grad->data<T>();
    auto src_dims = phi::vectorize<int64_t>(out_grad->dims());
    TransposeMKLDNNHandler<T> handler(src_dims, reversed_axis, mkldnn_engine);

    auto src_memory_p = handler.AcquireSrcMemory(
        out_grad->format(), platform::to_void_cast<T>(out_grad_data));
    // AcquireDstMemory allocates x_grad itself (sized from the dst
    // descriptor), so no separate mutable_data call is needed here.
    auto dst_memory_p = handler.AcquireDstMemory(x_grad, ctx.GetPlace());
    auto reorder_p = handler.AcquireTranspose(dst_memory_p, src_memory_p);

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    reorder_p->execute(astream, *src_memory_p, *dst_memory_p);
    astream.wait();

    // Same as the forward kernel: the reorder produced a plain buffer, so
    // clear any oneDNN layout/format x_grad may still carry.
    x_grad->set_layout(DataLayout::kNCHW);
    x_grad->set_format(MKLDNNMemoryFormat::undef);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

189 190 191 192
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2,
                                    MKLDNN,
                                    ::paddle::platform::CPUPlace,
                                    FP32,
193 194 195
                                    ops::kTransposeMKLDNNFP32,
                                    ops::TransposeMKLDNNOpKernel<float>);

196 197 198 199
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2,
                                    MKLDNN,
                                    ::paddle::platform::CPUPlace,
                                    U8,
200 201 202
                                    ops::kTransposeMKLDNNINT8,
                                    ops::TransposeMKLDNNOpKernel<uint8_t>);

203 204 205 206
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2,
                                    MKLDNN,
                                    ::paddle::platform::CPUPlace,
                                    S8,
207 208 209
                                    ops::kTransposeMKLDNNINT8,
                                    ops::TransposeMKLDNNOpKernel<int8_t>);

210
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
211 212 213 214
    transpose2,
    MKLDNN,
    ::paddle::platform::CPUPlace,
    BF16,
215 216 217
    ops::kTransposeMKLDNNFP32,
    ops::TransposeMKLDNNOpKernel<paddle::platform::bfloat16>);

218 219 220
REGISTER_OP_KERNEL(transpose,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
221
                   ops::TransposeMKLDNNOpKernel<float>);
222

223 224 225
REGISTER_OP_KERNEL(transpose_grad,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
226
                   ops::TransposeMKLDNNGradOpKernel<float>);