// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/transpose_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void TransposeKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const std::vector<int>& axis,
                     DenseTensor* out) {
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
  // Here we need to match dims to paddle layout
  // as we are producing non-oneDNN result
  auto x_dims = x.dims();
  if ((x_dims.size() >= 3) &&
      (phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
       phi::DataLayout::kNHWC)) {
    int axis_size = axis.size();
    std::vector<int> formated_axis = axis;
    std::vector<int> count(axis_size, 0);
    for (int i = 0; i < axis_size; i++) {
      if (axis[i] < 0) {
        formated_axis[i] = axis[i] + axis_size;
      }
    }
    auto dims = phi::vectorize<int>(x_dims);

    std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
    x_dims = x_dims.reshape(dims);
    VLOG(3)
        << "Rotating Shape in Transpose from: kMKLDNN to: kNHWC output_shape";

    phi::DDim out_dims(x_dims);
    for (size_t i = 0; i < axis.size(); i++) {
      out_dims[i] = x_dims[formated_axis[i]];
    }
    out->Resize(out_dims);
  }

54
  PADDLE_ENFORCE_EQ(
55 56
      dev_ctx.GetPlace().GetType(),
      AllocationType::CPU,
57 58
      errors::PreconditionNotMet("oneDNN Transpose kernel must use CPUPlace"));

59
  if (axis.size() == 1 || axis.size() == 0) {
60
    Copy<Context>(dev_ctx, x, x.place(), false, out);
61 62 63 64 65 66 67
    out->set_mem_desc(x.mem_desc());
    return;
  }

  auto x_vec_dims = vectorize(x.dims());
  auto x_type = funcs::ToOneDNNDataType(x.dtype());
  funcs::ReorderOneDNNHandler reorder_handler(
68
      x_vec_dims, x.dtype(), x_type, dev_ctx.GetEngine());
69 70 71
  auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
      x.mem_desc(), funcs::to_void_cast(x.data<T>()));

72
  auto fake_strides = funcs::FakeTransposeStrides(x_vec_dims, axis);
73 74
  auto dst_md =
      dnnl::memory::desc(x_vec_dims, x.mem_desc().data_type(), fake_strides);
75
  auto reorder_dst_memory_p =
76
      reorder_handler.AcquireDstMemory(out, dst_md, dev_ctx.GetPlace());
77 78
  auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
                                                  reorder_src_memory_p);
79 80 81 82

  auto& astream = OneDNNContext::tls().get_stream();
  reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
  astream.wait();
83 84
  out->set_mem_desc(reorder_dst_memory_p->get_desc().permute_axes(
      funcs::TransposeToPermuteAxes(axis)));
85 86 87 88 89 90 91 92 93 94 95
}
}  // namespace phi

// Register the oneDNN transpose kernel for the dtypes oneDNN reorder
// supports on CPU: float, uint8, int8 and bfloat16.
PD_REGISTER_KERNEL(transpose,
                   OneDNN,
                   ONEDNN,
                   phi::TransposeKernel,
                   float,
                   uint8_t,
                   int8_t,
                   phi::dtype::bfloat16) {}