transpose_mkldnn_op.cc 8.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
18
#include "paddle/fluid/operators/transpose_op.h"
19 20 21 22 23
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

24
using Tensor = phi::DenseTensor;
25 26
using framework::DataLayout;

27 28 29 30 31
template <typename T>
class TransposeMKLDNNHandler {
 public:
  TransposeMKLDNNHandler(std::vector<int64_t>& dims,  // NOLINT
                         std::vector<int>& axis,      // NOLINT
32
                         dnnl::engine engine)
33 34 35 36 37
      : dims_(dims),
        axis_(axis),
        logical_axis_(dims.size(), 0),
        engine_(engine) {}

38 39
  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const MKLDNNMemoryFormat& fmt,
                                                 void* ptr) {
40 41 42 43 44 45 46 47 48 49
    // Make memory descriptor using input format, unless it
    // cannot be trusted (nchw) then make up memory fmt manually
    for (size_t i = 0; i < this->logical_axis_.size(); ++i) {
      this->logical_axis_[i] = i;
    }

    auto src_md = fmt != MKLDNNMemoryFormat::nchw
                      ? platform::MKLDNNMemDesc(
                            dims_, platform::MKLDNNGetDataType<T>(), fmt)
                      : Axis2MemoryDesc(dims_, logical_axis_);
50
    return std::make_shared<dnnl::memory>(src_md, engine_, ptr);
51 52
  }

53
  std::shared_ptr<dnnl::memory> AcquireDstMemory(phi::DenseTensor* output,
54
                                                 platform::Place place) {
55 56
    auto dst_md = Axis2MemoryDesc(dims_, axis_);
    auto dst_data = output->mutable_data<T>(place, dst_md.get_size());
57
    return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
58 59
  }

60 61 62 63
  std::shared_ptr<dnnl::reorder> AcquireTranspose(
      std::shared_ptr<dnnl::memory> dst_memory_p,
      std::shared_ptr<dnnl::memory> src_memory_p) {
    return std::make_shared<dnnl::reorder>(*(src_memory_p), *(dst_memory_p));
64 65 66
  }

 protected:
67 68
  dnnl::memory::desc Axis2MemoryDesc(std::vector<int64_t>& nchw_tz,  // NOLINT
                                     std::vector<int>& axis          // NOLINT
69
  ) {
70 71 72 73 74 75 76 77
    size_t ndims = axis.size();

    std::vector<int64_t> strides(ndims);
    unsigned int total_stride = 1;
    for (int i = ndims - 1; i >= 0; --i) {
      strides[axis[i]] = total_stride;
      total_stride *= nchw_tz[axis[i]];
    }
78 79
    dnnl::memory::desc mem_d(
        nchw_tz, platform::MKLDNNGetDataType<T>(), strides);
80 81 82 83 84 85 86 87

    return mem_d;
  }

 private:
  std::vector<int64_t> dims_;
  std::vector<int> axis_;
  std::vector<int> logical_axis_;
88
  dnnl::engine engine_;
89 90
};

91 92 93 94
template <typename T>
class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  // Forward oneDNN transpose: Out = transpose(X, axis).
  // For a single-element `axis`, X is copied to Out together with its format.
  // Otherwise a reorder primitive writes X into Out through permuted strides.
  // Out is then marked with kNCHW layout and an undef format.
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
                      true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL Transpose must use CPUPlace"));

    auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
    const auto& engine = dev_ctx.GetEngine();
    std::vector<int> perm = ctx.Attr<std::vector<int>>("axis");

    auto* x = ctx.Input<phi::DenseTensor>("X");
    auto* out = ctx.Output<phi::DenseTensor>("Out");
    const T* x_data = x->data<T>();

    // A one-element permutation is the identity: plain copy.
    if (perm.size() == 1) {
      framework::TensorCopy(*x, x->place(), out);
      out->set_format(x->format());
      return;
    }

    auto src_dims = phi::vectorize<int64_t>(x->dims());
    TransposeMKLDNNHandler<T> handler(src_dims, perm, engine);

    auto src_mem = handler.AcquireSrcMemory(
        x->format(), platform::to_void_cast<T>(x_data));
    auto dst_mem = handler.AcquireDstMemory(out, ctx.GetPlace());
    auto transpose_prim = handler.AcquireTranspose(dst_mem, src_mem);

    auto& dnnl_stream = platform::MKLDNNDeviceContext::tls().get_stream();
    transpose_prim->execute(dnnl_stream, *src_mem, *dst_mem);
    dnnl_stream.wait();

    out->set_layout(DataLayout::kNCHW);
    out->set_format(MKLDNNMemoryFormat::undef);
  }
};

135 136 137 138
template <typename T>
class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 public:
  // Backward oneDNN transpose: X@GRAD = transpose(Out@GRAD, inverse(axis)).
  // Returns immediately when X@GRAD is not requested. For a single-element
  // `axis` the gradient is copied together with its format. Otherwise the
  // inverse permutation is applied with a reorder, and X@GRAD is marked with
  // kNCHW layout and an undef format, matching the forward kernel's Out.
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
                      true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL TransposeGrad must use CPUPlace"));
    auto* out_grad = ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
    auto* x_grad = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
    if (!x_grad) return;
    auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();
    std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
    std::vector<int> reversed_axis(axis);
    if (axis.size() == 1) {
      framework::TensorCopy(*out_grad, out_grad->place(), x_grad);
      x_grad->set_format(out_grad->format());
      return;
    }

    // Inverse permutation: the forward pass took Out dim i from X dim
    // axis[i], so X dim axis[i] comes back from Out@GRAD dim i.
    for (size_t i = 0; i < axis.size(); ++i) {
      reversed_axis[axis[i]] = static_cast<int>(i);
    }

    const T* out_grad_data = out_grad->data<T>();
    x_grad->mutable_data<T>(ctx.GetPlace());

    auto nchw_tz = phi::vectorize<int64_t>(out_grad->dims());

    TransposeMKLDNNHandler<T> handler(nchw_tz, reversed_axis, mkldnn_engine);

    auto transpose_src_memory_p = handler.AcquireSrcMemory(
        out_grad->format(), platform::to_void_cast<T>(out_grad_data));
    auto transpose_dst_memory_p =
        handler.AcquireDstMemory(x_grad, ctx.GetPlace());
    auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
                                                transpose_src_memory_p);

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    transpose_p->execute(
        astream, *transpose_src_memory_p, *transpose_dst_memory_p);
    astream.wait();

    // The reorder wrote a dense, plain layout. Mark x_grad accordingly, as
    // the forward kernel does for Out, so that a layout/format left on the
    // tensor by an earlier use is not applied to this data.
    x_grad->set_layout(DataLayout::kNCHW);
    x_grad->set_format(MKLDNNMemoryFormat::undef);
  }
};

183 184 185 186 187
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

188 189 190 191
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2,
                                    MKLDNN,
                                    ::paddle::platform::CPUPlace,
                                    FP32,
192 193 194
                                    ops::kTransposeMKLDNNFP32,
                                    ops::TransposeMKLDNNOpKernel<float>);

195 196 197 198
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2,
                                    MKLDNN,
                                    ::paddle::platform::CPUPlace,
                                    U8,
199 200 201
                                    ops::kTransposeMKLDNNINT8,
                                    ops::TransposeMKLDNNOpKernel<uint8_t>);

202 203 204 205
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2,
                                    MKLDNN,
                                    ::paddle::platform::CPUPlace,
                                    S8,
206 207 208
                                    ops::kTransposeMKLDNNINT8,
                                    ops::TransposeMKLDNNOpKernel<int8_t>);

209
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
210 211 212 213
    transpose2,
    MKLDNN,
    ::paddle::platform::CPUPlace,
    BF16,
214 215 216
    ops::kTransposeMKLDNNFP32,
    ops::TransposeMKLDNNOpKernel<paddle::platform::bfloat16>);

217 218 219
REGISTER_OP_KERNEL(transpose,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
220
                   ops::TransposeMKLDNNOpKernel<float>);
221

222 223 224
REGISTER_OP_KERNEL(transpose_grad,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
225
                   ops::TransposeMKLDNNGradOpKernel<float>);
226

227 228 229
REGISTER_OP_KERNEL(transpose2_grad,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
230
                   ops::TransposeMKLDNNGradOpKernel<float>);