transfer_layout_kernel.cc 6.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/phi/kernels/transfer_layout_kernel.h"
16

17 18 19
#include <sstream>
#include <string>

20 21
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
22
#include "paddle/phi/core/visit_type.h"
23
#include "paddle/phi/kernels/funcs/data_layout_transform.h"
24
#include "paddle/phi/kernels/funcs/math_function.h"
25
#include "paddle/phi/kernels/memcpy_kernel.h"
26
#ifdef PADDLE_WITH_MKLDNN
27
#include "paddle/phi/backends/onednn/onednn_helper.h"
28
#endif
29
namespace phi {
30 31 32 33 34

std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to) {
  PADDLE_ENFORCE_NE(
      from,
      to,
35
      phi::errors::InvalidArgument(
36 37 38 39 40 41
          "Layout transform should transform between different layout."));
  if (from == DataLayout::NCHW && to == DataLayout::NHWC) {
    return {0, 2, 3, 1};
  } else if (from == DataLayout::NHWC && to == DataLayout::NCHW) {
    return {0, 3, 1, 2};
  } else {
42
    PADDLE_THROW(phi::errors::InvalidArgument("Unsupported layout transform."));
43 44 45 46 47 48 49 50
  }
}

template <typename T, typename Context>
void CastDataLayout(const Context& dev_ctx,
                    const DenseTensor& x,
                    const std::vector<int>& axis,
                    DenseTensor* out) {
51
  funcs::Transpose<Context, T, 4> trans4;
52 53 54 55
  trans4(dev_ctx, x, out, axis);
}

template <typename Context>
56 57 58 59
void TransferLayoutGeneral(const Context& dev_ctx,
                           const DenseTensor& x,
                           DataLayout dst_layout,
                           DenseTensor* out) {
60 61 62 63 64 65 66 67 68 69
  auto src_dim = x.dims();

  auto axis = GetAxis(x.layout(), dst_layout);

  std::vector<int64_t> dst_dim;
  dst_dim.resize(axis.size());
  for (size_t i = 0; i < axis.size(); i++) {
    dst_dim[i] = src_dim[axis[i]];
  }

70 71
  out->Resize(phi::make_ddim(dst_dim));
  dev_ctx.Alloc(out, x.dtype());
72 73 74 75 76 77

  PD_VISIT_ALL_TYPES(x.dtype(), "CastDataLayout", ([&] {
                       CastDataLayout<data_t, Context>(dev_ctx, x, axis, out);
                     }));
}

78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
#ifdef PADDLE_WITH_MKLDNN
template <typename Context>
void TransferLayoutMKLDNN(const Context& dev_ctx,
                          const DenseTensor& x,
                          DataLayout src_layout,
                          DataLayout dst_layout,
                          DenseTensor* out) {
  auto print_tensor_meta = [](const DenseTensor& x) {
    std::ostringstream oss;

    oss << "[";
    oss << "layout:" << x.layout() << " ,";
    oss << "dims:" << x.dims() << " ,";
    if (x.IsInitialized()) oss << "place:" << x.place();
    oss << "]";

    return oss.str();
  };
  VLOG(10) << " x: " << print_tensor_meta(x);
  VLOG(10) << " out: " << print_tensor_meta(*out) << " " << out;

  // NOTE(zhiqiu): to handle the special case in ApplyDataTransform() in
  // data_transfer.cc
101
  if (!x.IsInitialized() && src_layout == DataLayout::ONEDNN &&
102 103 104 105 106 107 108 109
      dst_layout == DataLayout::NHWC) {
    VLOG(4) << src_layout << "->" << dst_layout << " " << x.layout();
    out->Resize(x.dims());
    out->set_layout(dst_layout);
    funcs::MatchShapeToLayout(out, src_layout, dst_layout);
    return;
  }

110
  if (src_layout != DataLayout::ONEDNN && dst_layout == DataLayout::ONEDNN) {
111 112
    // Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
    // Just set layout/format. No real transform occur
113 114
    auto out_format = funcs::OneDNNFormatForSize(
        x.dims().size(), funcs::ToOneDNNFormat(src_layout));
115 116 117 118 119 120 121 122 123 124

    out->ShareDataWith(x);
    // For NHWC data we need reshape of tensors as MKL-DNN
    // is expecting NHWC dims description order
    if (src_layout == DataLayout::NHWC) {
      VLOG(4) << "NHWC";
      funcs::MatchShapeToLayout(out, src_layout, dst_layout);
      OneDNNContext::tls().set_cur_paddle_data_layout(src_layout);
    }

125 126 127 128
    dnnl::memory::desc out_mem_desc(vectorize<int64_t>(out->dims()),
                                    funcs::ToOneDNNDataType(x.dtype()),
                                    out_format);
    out->set_mem_desc(out_mem_desc);
129 130
  } else if (src_layout == DataLayout::ONEDNN &&
             dst_layout != DataLayout::ONEDNN) {
131 132
    // Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel
    // Do transform via MKLDNN lib
133
    funcs::TransDataLayoutFromOneDNN(
134
        src_layout, dst_layout, x, out, dev_ctx.GetPlace());
135 136
  } else if (src_layout == DataLayout::ONEDNN &&
             dst_layout == DataLayout::ONEDNN) {
137 138 139 140
    PADDLE_ENFORCE_NE(
        src_layout,
        dst_layout,
        errors::PreconditionNotMet(
141
            "No layout transform needed between two oneDNN OPKernels."));
142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
  } else {
    TransferLayoutGeneral<Context>(dev_ctx, x, dst_layout, out);
  }
}
#endif

template <typename Context>
void TransferLayoutKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          int src_layout,
                          int dst_layout,
                          DenseTensor* out) {
  PADDLE_ENFORCE_NE(src_layout,
                    dst_layout,
                    errors::PreconditionNotMet(
                        "No layout transform needed between same layout."));
  VLOG(10) << "TransDataLayout from " << static_cast<DataLayout>(src_layout)
           << " -> " << static_cast<DataLayout>(dst_layout);

161 162 163 164 165 166 167
  VLOG_IF(10, x.initialized()) << "TransDataLayout from " << x.layout();
  if (x.layout() == static_cast<DataLayout>(dst_layout)) {
    VLOG(10) << "No need to transform, already is " << x.layout();
    Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
    return;
  }

168 169 170 171 172 173 174 175 176 177 178 179
#ifdef PADDLE_WITH_MKLDNN
  TransferLayoutMKLDNN<Context>(dev_ctx,
                                x,
                                static_cast<DataLayout>(src_layout),
                                static_cast<DataLayout>(dst_layout),
                                out);
#else
  TransferLayoutGeneral<Context>(
      dev_ctx, x, static_cast<DataLayout>(dst_layout), out);
#endif
}

180
}  // namespace phi
181

182
PD_REGISTER_GENERAL_KERNEL(transfer_layout,
183 184
                           CPU,
                           ALL_LAYOUT,
185
                           phi::TransferLayoutKernel<phi::CPUContext>,
186
                           ALL_DTYPE) {}
187 188 189 190 191 192 193
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(transfer_layout,
                           GPU,
                           ALL_LAYOUT,
                           phi::TransferLayoutKernel<phi::GPUContext>,
                           ALL_DTYPE) {}
#endif