scale_api.h 9.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "glog/logging.h"

19 20 21 22
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/allocator.h"
23
#include "paddle/phi/api/lib/utils/storage.h"
24 25 26 27 28 29
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/scale_kernel.h"
30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45

namespace paddle {
namespace experimental {

// Runs the "scale" kernel through the generic phi::KernelContext path.
//
// The kernel is selected from phi::KernelFactory by a key derived from `x`.
// The input, the attributes and the output are pushed into a
// phi::KernelContext, and then the type-erased kernel function is invoked.
//
// Args:
//   x:                input tensor; must be backed by a phi::DenseTensor.
//   scale:            converted to phi::Scalar and passed as an attribute.
//   bias:             passed to the kernel unchanged as an attribute.
//   bias_after_scale: passed to the kernel unchanged as an attribute.
// Returns:
//   A new Tensor allocated on the place of the selected backend. Its meta is
//   produced by phi::UnchangedInferMeta from `x`, and its data is written by
//   the kernel.
// Errors:
//   Raised if `x` is not a DenseTensor, or (from SelectKernelOrThrowError)
//   if no "scale" kernel matches the derived key.
//
// Marked `inline` because the definition lives in a header. A non-inline
// definition would produce duplicate symbols as soon as two translation units
// include this file.
PADDLE_API inline Tensor scale_kernel_context(const Tensor& x,
                                              const Scalar& scale,
                                              float bias,
                                              bool bias_after_scale) {
  // Every key component starts UNDEFINED and is filled in from the input.
  // This mirrors the structure of the generated API functions.
  Backend kernel_backend = Backend::UNDEFINED;
  DataLayout kernel_layout = DataLayout::UNDEFINED;
  DataType kernel_data_type = DataType::UNDEFINED;

  if (kernel_backend == Backend::UNDEFINED ||
      kernel_layout == DataLayout::UNDEFINED ||
      kernel_data_type == DataType::UNDEFINED) {
    auto kernel_key_set = ParseKernelKeyByInputArgs(x);
    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
    if (kernel_backend == Backend::UNDEFINED) {
      kernel_backend = kernel_key.backend();
    }
    if (kernel_layout == DataLayout::UNDEFINED) {
      kernel_layout = kernel_key.layout();
    }
    if (kernel_data_type == DataType::UNDEFINED) {
      kernel_data_type = kernel_key.dtype();
    }
  }

  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      "scale", {kernel_backend, kernel_layout, kernel_data_type});
  VLOG(6) << "scale API kernel key: [" << kernel_backend << ", "
          << kernel_layout << ", " << kernel_data_type << "]";
  VLOG(6) << "scale API kernel: " << kernel;

  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
  auto kernel_context = phi::KernelContext(dev_ctx);

  // dynamic_pointer_cast yields an empty pointer when `x` is not backed by a
  // DenseTensor. Fail with a clear error instead of dereferencing it in
  // InferMeta below.
  auto dense_x = std::dynamic_pointer_cast<phi::DenseTensor>(x.impl());
  if (dense_x == nullptr) {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "The input tensor of scale API must be a DenseTensor."));
  }
  kernel_context.EmplaceBackInput(dense_x.get());

  // Attributes are pushed in the same order as ScaleKernel's parameters:
  // scale, bias, bias_after_scale.
  kernel_context.EmplaceBackAttr(phi::Scalar(scale));
  kernel_context.EmplaceBackAttr(bias);
  kernel_context.EmplaceBackAttr(bias_after_scale);

  // Allocate the output on the kernel's place and infer its meta from `x`.
  auto dense_out = std::make_shared<phi::DenseTensor>(
      phi::make_intrusive<paddle::experimental::SharedStorage>(
          phi::TransToPhiPlace(kernel_backend)),
      phi::DenseTensorMeta());
  phi::MetaTensor meta_out(dense_out.get());
  phi::UnchangedInferMeta(*dense_x, &meta_out);
  kernel_context.EmplaceBackOutput(dense_out.get());

  Tensor out;
  out.set_impl(dense_out);

  kernel(&kernel_context);
  return out;
}

static void ScaleCPU(DataType kernel_dtype,
89 90
                     const phi::CPUContext& dev_ctx,
                     const phi::DenseTensor& x,
91 92 93
                     const Scalar& scale,
                     float bias,
                     bool bias_after_scale,
94
                     phi::DenseTensor* dense_out) {
95
  switch (kernel_dtype) {
96 97 98
    case phi::DataType::FLOAT64: {
      phi::ScaleKernel<double>(
          dev_ctx, x, phi::Scalar(scale), bias, bias_after_scale, dense_out);
99 100
      break;
    }
101 102 103
    case phi::DataType::FLOAT32: {
      phi::ScaleKernel<float>(
          dev_ctx, x, phi::Scalar(scale), bias, bias_after_scale, dense_out);
104 105
      break;
    }
106 107 108
    case phi::DataType::BFLOAT16: {
      phi::ScaleKernel<phi::dtype::bfloat16>(
          dev_ctx, x, phi::Scalar(scale), bias, bias_after_scale, dense_out);
109 110
      break;
    }
111 112 113
    case phi::DataType::INT64: {
      phi::ScaleKernel<int64_t>(
          dev_ctx, x, phi::Scalar(scale), bias, bias_after_scale, dense_out);
114 115
      break;
    }
116 117 118
    case phi::DataType::INT32: {
      phi::ScaleKernel<int32_t>(
          dev_ctx, x, phi::Scalar(scale), bias, bias_after_scale, dense_out);
119 120
      break;
    }
121 122 123
    case phi::DataType::INT16: {
      phi::ScaleKernel<int16_t>(
          dev_ctx, x, phi::Scalar(scale), bias, bias_after_scale, dense_out);
124 125
      break;
    }
126 127 128
    case phi::DataType::INT8: {
      phi::ScaleKernel<int8_t>(
          dev_ctx, x, phi::Scalar(scale), bias, bias_after_scale, dense_out);
129 130
      break;
    }
131 132 133
    case phi::DataType::UINT8: {
      phi::ScaleKernel<uint8_t>(
          dev_ctx, x, phi::Scalar(scale), bias, bias_after_scale, dense_out);
134 135 136 137 138 139 140 141 142 143 144 145 146
      break;
    }
    default: {
      PADDLE_THROW(paddle::platform::errors::Fatal(
          "Detected unsupported data type."
          "Only Float64, Float32, BFloat16, Int64, Int32, Int16, Int8, UInt8 "
          "are supported for now."));
      break;
    }
  }
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Calls the templated phi::ScaleKernel on a GPU device context.
//
// Compiled only for CUDA or HIP builds. The element type is chosen by
// `kernel_dtype`. Only the dtypes listed in the switch are handled; any other
// dtype raises a Fatal error.
//
// Args:
//   kernel_dtype:     data type used to pick the ScaleKernel instantiation.
//   dev_ctx:          GPU device context the kernel runs on.
//   x:                input dense tensor.
//   scale:            converted once to phi::Scalar before dispatch.
//   bias:             passed to the kernel unchanged.
//   bias_after_scale: passed to the kernel unchanged.
//   dense_out:        output tensor written by the kernel.
static void ScaleGPU(DataType kernel_dtype,
                     const phi::GPUContext& dev_ctx,
                     const phi::DenseTensor& x,
                     const Scalar& scale,
                     float bias,
                     bool bias_after_scale,
                     phi::DenseTensor* dense_out) {
  // Convert the API-level Scalar once instead of once per case.
  const phi::Scalar phi_scale(scale);
  switch (kernel_dtype) {
    case phi::DataType::FLOAT64:
      phi::ScaleKernel<double>(
          dev_ctx, x, phi_scale, bias, bias_after_scale, dense_out);
      break;
    case phi::DataType::FLOAT32:
      phi::ScaleKernel<float>(
          dev_ctx, x, phi_scale, bias, bias_after_scale, dense_out);
      break;
    case phi::DataType::FLOAT16:
      phi::ScaleKernel<phi::dtype::float16>(
          dev_ctx, x, phi_scale, bias, bias_after_scale, dense_out);
      break;
    case phi::DataType::INT64:
      phi::ScaleKernel<int64_t>(
          dev_ctx, x, phi_scale, bias, bias_after_scale, dense_out);
      break;
    case phi::DataType::INT32:
      phi::ScaleKernel<int32_t>(
          dev_ctx, x, phi_scale, bias, bias_after_scale, dense_out);
      break;
    case phi::DataType::INT16:
      phi::ScaleKernel<int16_t>(
          dev_ctx, x, phi_scale, bias, bias_after_scale, dense_out);
      break;
    case phi::DataType::INT8:
      phi::ScaleKernel<int8_t>(
          dev_ctx, x, phi_scale, bias, bias_after_scale, dense_out);
      break;
    case phi::DataType::UINT8:
      phi::ScaleKernel<uint8_t>(
          dev_ctx, x, phi_scale, bias, bias_after_scale, dense_out);
      break;
    default:
      // Adjacent string literals are concatenated, so each piece needs its
      // own trailing space to keep the sentences separated.
      PADDLE_THROW(paddle::platform::errors::Fatal(
          "Detected unsupported data type. "
          "Only Float64, Float32, Float16, Int64, Int32, Int16, Int8, UInt8 "
          "are supported for now."));
  }
}
#endif

// Runs "scale" by calling the templated phi::ScaleKernel directly through a
// switch on backend and dtype, bypassing phi::KernelContext.
//
// The kernel found by SelectKernelOrThrowError is only logged. The lookup
// still raises an error if no "scale" kernel is registered for the derived
// key.
//
// Args:
//   x:                input tensor; must be backed by a phi::DenseTensor.
//   scale:            forwarded to ScaleCPU / ScaleGPU.
//   bias:             forwarded to ScaleCPU / ScaleGPU.
//   bias_after_scale: forwarded to ScaleCPU / ScaleGPU.
// Returns:
//   A new Tensor allocated on the place of the selected backend. Its meta is
//   produced by phi::UnchangedInferMeta from `x`.
// Errors:
//   Raised if `x` is not a DenseTensor, or if the backend is not CPU (or GPU
//   in CUDA/HIP builds). Unsupported dtypes raise errors inside the helpers.
//
// Marked `inline` because the definition lives in a header. A non-inline
// definition would produce duplicate symbols as soon as two translation units
// include this file.
inline Tensor scale_switch_case(const Tensor& x,
                                const Scalar& scale,
                                float bias,
                                bool bias_after_scale) {
  // Every key component starts UNDEFINED and is filled in from the input.
  // This mirrors the structure of the generated API functions.
  Backend kernel_backend = Backend::UNDEFINED;
  DataLayout kernel_layout = DataLayout::UNDEFINED;
  DataType kernel_data_type = DataType::UNDEFINED;

  if (kernel_backend == Backend::UNDEFINED ||
      kernel_layout == DataLayout::UNDEFINED ||
      kernel_data_type == DataType::UNDEFINED) {
    auto kernel_key_set = ParseKernelKeyByInputArgs(x);
    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
    if (kernel_backend == Backend::UNDEFINED) {
      kernel_backend = kernel_key.backend();
    }
    if (kernel_layout == DataLayout::UNDEFINED) {
      kernel_layout = kernel_key.layout();
    }
    if (kernel_data_type == DataType::UNDEFINED) {
      kernel_data_type = kernel_key.dtype();
    }
  }

  // Only used for logging and for validating that a kernel is registered.
  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      "scale", {kernel_backend, kernel_layout, kernel_data_type});
  VLOG(6) << "scale API kernel key: [" << kernel_backend << ", "
          << kernel_layout << ", " << kernel_data_type << "]";
  VLOG(6) << "scale API kernel: " << kernel;

  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);

  // dynamic_pointer_cast yields an empty pointer when `x` is not backed by a
  // DenseTensor. Fail with a clear error instead of dereferencing it below.
  auto dense_x = std::dynamic_pointer_cast<phi::DenseTensor>(x.impl());
  if (dense_x == nullptr) {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "The input tensor of scale API must be a DenseTensor."));
  }

  // Allocate the output on the kernel's place and infer its meta from `x`.
  auto dense_out = std::make_shared<phi::DenseTensor>(
      phi::make_intrusive<paddle::experimental::SharedStorage>(
          phi::TransToPhiPlace(kernel_backend)),
      phi::DenseTensorMeta());
  phi::MetaTensor meta_out(dense_out.get());
  phi::UnchangedInferMeta(*dense_x, &meta_out);

  Tensor out;
  out.set_impl(dense_out);

  switch (kernel_backend) {
    case Backend::CPU:
      ScaleCPU(kernel_data_type,
               static_cast<const phi::CPUContext&>(*dev_ctx),
               *dense_x,
               scale,
               bias,
               bias_after_scale,
               dense_out.get());
      break;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    case Backend::GPU:
      ScaleGPU(kernel_data_type,
               static_cast<const phi::GPUContext&>(*dev_ctx),
               *dense_x,
               scale,
               bias,
               bias_after_scale,
               dense_out.get());
      break;
#endif
    default:
      // Adjacent string literals are concatenated, so each piece needs its
      // own trailing space to keep the sentences separated.
      PADDLE_THROW(paddle::platform::errors::Fatal(
          "Detected unsupported backend. "
          "Only CPU and CUDA Backend are supported for now. "
          "Please double check if your backend falls into the above two "
          "categories."));
  }

  return out;
}

}  // namespace experimental
}  // namespace paddle