layout_compute.cc 9.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/cuda/layout_compute.h"
16
#include <vector>
17 18 19 20 21 22 23 24
#include "lite/backends/cuda/math/transpose.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {

25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
inline DDim trim_singular_dims(const DDim& dims) {
  auto actual_dims_size = dims.size();
  for (; actual_dims_size != 0; --actual_dims_size) {
    if (dims[actual_dims_size - 1] != 1) break;
  }
  std::vector<int64_t> trim_dims;
  trim_dims.resize(actual_dims_size);
  for (int i = 0; i < actual_dims_size; ++i) {
    trim_dims[i] = dims[i];
  }
  if (trim_dims.size() == 0) {
    return DDim();
  }
  return DDim(trim_dims);
}

// NCHWTONHWC(type): shared body of the NCHW -> NHWC kernels' Run() methods,
// where `type` is the tensor element type (float / int8_t at the call sites).
// It reads param.x and writes param.y, launching on the kernel's CUDAContext.
//
// Fast path: trim_singular_dims() drops trailing size-1 dims. When at most
// one dim remains, every dim after the first is 1 (or all dims are 1), so the
// NHWC element order equals the NCHW order and the tensor is copied instead
// of transposed. Testing `<= 1` (not `== 1`) also covers all-ones shapes
// (trimmed rank 0), which would otherwise fail the rank-4 CHECK for inputs
// such as {1} or {1, 1} even though a plain copy is exact.
// Otherwise the input must be 4-D {n, c, h, w}; y is resized to {n, h, w, c}.
//
// Comments are kept outside the macro on purpose: a `//` comment on a line
// ending in `\` would also comment out the spliced next line.
#define NCHWTONHWC(type)                                                  \
  auto& param = this->template Param<param_t>();                          \
  auto& ctx = this->ctx_->template As<CUDAContext>();                     \
  auto input = param.x->template data<type>();                            \
  auto input_dim = param.x->dims();                                       \
  DDim input_trim_dim = trim_singular_dims(input_dim);                    \
  if (input_trim_dim.size() <= 1) {                                       \
    param.y->CopyDataFrom(*param.x);                                      \
    return;                                                               \
  }                                                                       \
  CHECK(input_dim.size() == 4)                                            \
      << "NCHW to NHWC should guarantee that the input dims should be 4"  \
      << ", but got " << input_dim.size();                                \
  int n = input_dim[0];                                                   \
  int c = input_dim[1];                                                   \
  int h = input_dim[2];                                                   \
  int w = input_dim[3];                                                   \
  param.y->Resize({n, h, w, c});                                          \
  auto output = param.y->template mutable_data<type>(TARGET(kCUDA));      \
  lite::cuda::math::NCHW2NHWC<type>(n, c, h * w, input, output, &ctx);

// NHWCTONCHW(type): shared body of the NHWC -> NCHW kernels' Run() methods,
// where `type` is the tensor element type (float / int8_t at the call sites).
// It reads param.x and writes param.y, launching on the kernel's CUDAContext.
//
// Fast path: trim_singular_dims() drops trailing size-1 dims. When at most
// one dim remains, every dim after the first is 1 (or all dims are 1), so the
// NCHW element order equals the NHWC order and the tensor is copied instead
// of transposed. Testing `<= 1` (not `== 1`) also covers all-ones shapes
// (trimmed rank 0), which would otherwise fail the rank-4 CHECK for inputs
// such as {1} or {1, 1} even though a plain copy is exact.
// Otherwise the input must be 4-D {n, h, w, c}; y is resized to {n, c, h, w}.
//
// Comments are kept outside the macro on purpose: a `//` comment on a line
// ending in `\` would also comment out the spliced next line.
#define NHWCTONCHW(type)                                                  \
  auto& param = this->template Param<param_t>();                          \
  auto& ctx = this->ctx_->template As<CUDAContext>();                     \
  auto input = param.x->template data<type>();                            \
  auto input_dim = param.x->dims();                                       \
  DDim input_trim_dim = trim_singular_dims(input_dim);                    \
  if (input_trim_dim.size() <= 1) {                                       \
    param.y->CopyDataFrom(*param.x);                                      \
    return;                                                               \
  }                                                                       \
  CHECK(input_dim.size() == 4)                                            \
      << "NHWC to NCHW should guarantee that the input dims should be 4"  \
      << ", but got " << input_dim.size();                                \
  int n = input_dim[0];                                                   \
  int h = input_dim[1];                                                   \
  int w = input_dim[2];                                                   \
  int c = input_dim[3];                                                   \
  param.y->Resize({n, c, h, w});                                          \
  auto output = param.y->template mutable_data<type>(TARGET(kCUDA));      \
  lite::cuda::math::NHWC2NCHW<type>(n, c, h * w, input, output, &ctx);

// Kernel entry points. Each Run() expands the matching conversion macro for
// its element type. The macro bodies contain a `return` on the copy fast
// path, so they are expanded directly at function-body scope.

// float NCHW -> NHWC.
void NCHWToNHWCCompute::Run() {
  NCHWTONHWC(float)
}

// int8 NCHW -> NHWC.
void NCHWToNHWCComputeInt8::Run() {
  NCHWTONHWC(int8_t)
}

// float NHWC -> NCHW.
void NHWCToNCHWCompute::Run() {
  NHWCTONCHW(float)
}

// int8 NHWC -> NCHW.
void NHWCToNCHWComputeInt8::Run() {
  NHWCTONCHW(int8_t)
}

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// `layout` op, CUDA / float, alias nchw2nhwc: NCHWToNHWCCompute takes a CUDA
// float "Input" in NCHW and produces a CUDA float "Out" in NHWC.
REGISTER_LITE_KERNEL(layout,
                     kCUDA,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::cuda::NCHWToNHWCCompute,
                     nchw2nhwc)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNCHW))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNHWC))})
    .Finalize();

// `layout` op, CUDA / float, alias nhwc2nchw: NHWCToNCHWCompute takes a CUDA
// float "Input" in NHWC and produces a CUDA float "Out" in NCHW.
// NOTE(review): the registration layout key is kNCHW although the bound input
// is NHWC; confirm kNCHW is the intended key for this alias.
REGISTER_LITE_KERNEL(layout,
                     kCUDA,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::cuda::NHWCToNCHWCompute,
                     nhwc2nchw)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNHWC))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNCHW))})
    .Finalize();

// `layout` op, CUDA / int8, alias int8_nchw2nhwc: NCHWToNHWCComputeInt8 takes
// a CUDA int8 "Input" in NCHW and produces a CUDA int8 "Out" in NHWC.
REGISTER_LITE_KERNEL(layout,
                     kCUDA,
                     kInt8,
                     kNCHW,
                     paddle::lite::kernels::cuda::NCHWToNHWCComputeInt8,
                     int8_nchw2nhwc)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kInt8),
                                      DATALAYOUT(kNCHW))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kInt8),
                                       DATALAYOUT(kNHWC))})
    .Finalize();

// `layout` op, CUDA / int8, alias int8_nhwc2nchw: NHWCToNCHWComputeInt8 takes
// a CUDA int8 "Input" in NHWC and produces a CUDA int8 "Out" in NCHW.
REGISTER_LITE_KERNEL(layout,
                     kCUDA,
                     kInt8,
                     kNCHW,
                     paddle::lite::kernels::cuda::NHWCToNCHWComputeInt8,
                     int8_nhwc2nchw)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kInt8),
                                      DATALAYOUT(kNHWC))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kInt8),
                                       DATALAYOUT(kNCHW))})
    .Finalize();

// `layout_once` op, CUDA / float, alias nchw2nhwc: reuses NCHWToNHWCCompute;
// binds a CUDA float NCHW "Input" and a CUDA float NHWC "Out".
REGISTER_LITE_KERNEL(layout_once,
                     kCUDA,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::cuda::NCHWToNHWCCompute,
                     nchw2nhwc)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNCHW))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNHWC))})
    .Finalize();

// `layout_once` op, CUDA / float, alias nhwc2nchw: reuses NHWCToNCHWCompute;
// binds a CUDA float NHWC "Input" and a CUDA float NCHW "Out".
REGISTER_LITE_KERNEL(layout_once,
                     kCUDA,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::cuda::NHWCToNCHWCompute,
                     nhwc2nchw)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kNHWC))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kNCHW))})
    .Finalize();

// `layout_once` op, CUDA / int8, alias int8_nchw2nhwc: reuses
// NCHWToNHWCComputeInt8; binds a CUDA int8 NCHW "Input" and a CUDA int8 NHWC
// "Out".
REGISTER_LITE_KERNEL(layout_once,
                     kCUDA,
                     kInt8,
                     kNCHW,
                     paddle::lite::kernels::cuda::NCHWToNHWCComputeInt8,
                     int8_nchw2nhwc)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kInt8),
                                      DATALAYOUT(kNCHW))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kInt8),
                                       DATALAYOUT(kNHWC))})
    .Finalize();

// `layout_once` op, CUDA / int8, alias int8_nhwc2nchw: reuses
// NHWCToNCHWComputeInt8; binds a CUDA int8 NHWC "Input" and a CUDA int8 NCHW
// "Out".
REGISTER_LITE_KERNEL(layout_once,
                     kCUDA,
                     kInt8,
                     kNCHW,
                     paddle::lite::kernels::cuda::NHWCToNCHWComputeInt8,
                     int8_nhwc2nchw)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kCUDA),
                                      PRECISION(kInt8),
                                      DATALAYOUT(kNHWC))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(TARGET(kCUDA),
                                       PRECISION(kInt8),
                                       DATALAYOUT(kNCHW))})
    .Finalize();