// label_smooth_kernel.cu
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/label_smooth_kernel.h"

#include <cstdint>
#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"

namespace phi {

// Elementwise functor for label smoothing without a prior distribution:
//   out = (1 - epsilon) * x + epsilon / label_dim
// Arithmetic is done in MPType (phi::dtype::MPTypeTrait<T>::Type) and the
// result is cast back to T.
template <typename T>
struct LabelSmoothFunctor {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType epsilon;
  MPType label_dim;
  // Loop-invariant coefficients, computed once on the host rather than once
  // per element on the device.
  MPType keep_ratio;     // 1 - epsilon
  MPType uniform_share;  // epsilon / label_dim

  // label_dim_data is int64_t so the (int64_t) size of the label's last
  // dimension is not narrowed at the call site; int arguments still convert.
  // Members are initialized in declaration order, so keep_ratio and
  // uniform_share may read epsilon and label_dim.
  __forceinline__ LabelSmoothFunctor(float epsilon_data,
                                     int64_t label_dim_data)
      : epsilon(static_cast<MPType>(epsilon_data)),
        label_dim(static_cast<MPType>(label_dim_data)),
        keep_ratio(static_cast<MPType>(1) - epsilon),
        uniform_share(epsilon / label_dim) {}

  __device__ __forceinline__ T operator()(const T x) const {
    return static_cast<T>(keep_ratio * static_cast<MPType>(x) + uniform_share);
  }
};

// Label smoothing against a prior distribution:
//   dst[i] = (1 - epsilon) * src[i] + epsilon * dist_data[i % dist_numel]
// Launch layout: 1-D grid of 1-D blocks. Each thread advances by the total
// number of launched threads, so any grid size covers all N elements.
// Indices are 64-bit, so tensors with 2^31 or more elements are addressed
// correctly.
// Precondition: dist_numel > 0 (it is used as a modulus).
template <typename T>
__global__ void LabelSmoothRunDistKernel(const int64_t N,
                                         const float epsilon,
                                         const int64_t dist_numel,
                                         const T* src,
                                         const T* dist_data,
                                         T* dst) {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  const MPType eps = static_cast<MPType>(epsilon);
  const MPType keep_ratio = static_cast<MPType>(1) - eps;
  // Widen before multiplying so the index/stride math cannot overflow int.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t idx =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < N;
       idx += stride) {
    const int64_t dist_idx = idx % dist_numel;
    dst[idx] = static_cast<T>(keep_ratio * static_cast<MPType>(src[idx]) +
                              eps * static_cast<MPType>(dist_data[dist_idx]));
  }
}

// GPU label smoothing.
//  - With prior_dist:    out = (1 - epsilon) * label + epsilon * prior_dist,
//    where prior_dist is indexed by the flat element index modulo its numel.
//  - Without prior_dist: out = (1 - epsilon) * label + epsilon / K,
//    where K is the size of label's last dimension.
// The output is always allocated; an empty label launches no work.
template <typename T, typename Context>
void LabelSmoothKernel(const Context& ctx,
                       const DenseTensor& label,
                       const paddle::optional<DenseTensor>& prior_dist,
                       float epsilon,
                       DenseTensor* out) {
  T* out_data = ctx.template Alloc<T>(out);
  const int64_t size_prob = label.numel();
  // Nothing to compute. Skipping also avoids a zero-sized grid launch, which
  // is an invalid launch configuration.
  if (size_prob == 0) {
    return;
  }
  const T* in_data = label.data<T>();

  if (prior_dist.get_ptr()) {
    const DenseTensor* dist_t = prior_dist.get_ptr();
    // NOTE(review): dist_numel must be > 0 because the kernel uses it as a
    // modulus. Consider enforcing this where inputs are validated.
    const int64_t dist_numel = dist_t->numel();
    const T* dist_data = dist_t->data<T>();

    constexpr int kThreads = 512;
    // The kernel strides over all elements, so capping the grid stays
    // correct. The cap keeps the block count in int range and keeps
    // blocks * kThreads well inside 32 bits.
    constexpr int64_t kMaxGrid = int64_t{1} << 20;
    const int64_t blocks_needed = (size_prob + kThreads - 1) / kThreads;
    const int grid =
        static_cast<int>(blocks_needed < kMaxGrid ? blocks_needed : kMaxGrid);
    LabelSmoothRunDistKernel<T><<<grid, kThreads, 0, ctx.stream()>>>(
        size_prob, epsilon, dist_numel, in_data, dist_data, out_data);
  } else {
    // NOTE(review): assumes label has rank >= 1 (indexes the last dim).
    const int64_t label_dim = label.dims()[label.dims().size() - 1];
    std::vector<const DenseTensor*> ins = {&label};
    std::vector<DenseTensor*> outs = {out};
    auto functor = LabelSmoothFunctor<T>(epsilon, label_dim);
    phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
  }
}

}  // namespace phi

// Registers phi::LabelSmoothKernel as the GPU "label_smooth" kernel for
// float, double, float16 and bfloat16, on all layouts.
PD_REGISTER_KERNEL(label_smooth,
                   GPU,
                   ALL_LAYOUT,
                   phi::LabelSmoothKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}