label_smooth_kernel.cu 3.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstdint>
#include <vector>
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/label_smooth_kernel.h"

namespace phi {

// Element-wise functor for label smoothing against a uniform distribution.
//
// For an input value x it yields
//   (1 - epsilon) * x + epsilon / label_dim
// with epsilon and label_dim stored as T.
template <typename T>
struct LabelSmoothFunctor {
  T epsilon;    // smoothing factor, converted to T at construction
  T label_dim;  // divisor applied to epsilon, converted to T at construction

  // Host-side constructor: converts both arguments to T once.
  __forceinline__ LabelSmoothFunctor(float epsilon_data, int label_dim_data)
      : epsilon(static_cast<T>(epsilon_data)),
        label_dim(static_cast<T>(label_dim_data)) {}

  // Device-side call operator: returns the smoothed value for one element.
  __device__ __forceinline__ T operator()(const T x) const {
    const T keep_weight = static_cast<T>(1 - epsilon);
    const T uniform_term = static_cast<T>(epsilon / label_dim);
    return (keep_weight * x + uniform_term);
  }
};

// Label smoothing against an explicit prior distribution:
//
//   dst[i] = (1 - epsilon) * src[i] + epsilon * dist_data[i % dist_numel]
//
// for every i in [0, N).
//
// Launch layout: 1-D grid of 1-D blocks. The loop below strides by the total
// number of launched threads, so any non-zero grid/block size covers all N
// elements. No shared memory is used.
//
// Preconditions (not checked here):
//   - src and dst hold at least N elements, dist_data at least dist_numel.
//   - dist_numel > 0 (it is used as a modulo divisor).
//
// Parameters:
//   N          number of elements to process.
//   epsilon    smoothing factor.
//   dist_numel number of elements in dist_data.
//   src        input values.
//   dist_data  prior distribution values.
//   dst        output values.
template <typename T>
__global__ void LabelSmoothRunDistKernel(const int64_t N,
                                         const float epsilon,
                                         const int64_t dist_numel,
                                         const T* src,
                                         const T* dist_data,
                                         T* dst) {
  // Convert epsilon to T *before* subtracting. Computing (1 - epsilon) in
  // float and widening afterwards makes the two weights not sum to 1 when T
  // is double.
  const T eps = static_cast<T>(epsilon);
  const T keep = static_cast<T>(1) - eps;

  // 64-bit indices so element counts above INT_MAX are not truncated.
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t idx =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < N;
       idx += stride) {
    const int64_t dist_idx = idx % dist_numel;
    dst[idx] = keep * src[idx] + eps * dist_data[dist_idx];
  }
}

// GPU label-smoothing kernel entry point.
//
// Allocates `out` and fills it from `label`:
//   - If `prior_dist` is set:
//       out[i] = (1 - epsilon) * label[i] + epsilon * prior[i % prior.numel()]
//     computed by LabelSmoothRunDistKernel on ctx.stream().
//   - Otherwise:
//       out[i] = (1 - epsilon) * label[i] + epsilon / label_dim
//     where label_dim is the size of the last dimension of `label`, computed
//     through the shared same-dims element-wise launcher.
//
// Parameters:
//   ctx        device context supplying the allocator and the stream.
//   label      input tensor.
//   prior_dist optional prior distribution tensor.
//   epsilon    smoothing factor.
//   out        output tensor; allocated here.
template <typename T, typename Context>
void LabelSmoothKernel(const Context& ctx,
                       const DenseTensor& label,
                       paddle::optional<const DenseTensor&> prior_dist,
                       float epsilon,
                       DenseTensor* out) {
  auto label_dim = label.dims()[label.dims().size() - 1];
  auto size_prob = label.numel();
  const T* in_data = label.data<T>();
  T* out_data = ctx.template Alloc<T>(out);

  // Empty input: there is nothing to compute, and the grid size below would
  // evaluate to 0, which is not a valid kernel launch configuration.
  if (size_prob == 0) {
    return;
  }

  if (prior_dist.get_ptr()) {
    const int threads = 512;
    // Ceil-division so the tail elements get a block as well.
    int grid = (size_prob + threads - 1) / threads;
    auto stream = ctx.stream();
    const auto* dist_t = prior_dist.get_ptr();
    // NOTE(review): dist_numel is used as a modulo divisor inside the kernel;
    // assumes the prior tensor is non-empty - confirm it is validated upstream.
    auto dist_numel = dist_t->numel();
    const T* dist_data = dist_t->data<T>();
    LabelSmoothRunDistKernel<T><<<grid, threads, 0, stream>>>(
        size_prob, epsilon, dist_numel, in_data, dist_data, out_data);

  } else {
    std::vector<const DenseTensor*> ins = {&label};
    std::vector<DenseTensor*> outs = {out};
    auto functor = LabelSmoothFunctor<T>(epsilon, label_dim);
    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
        ctx, ins, &outs, functor);
  }
}

}  // namespace phi

// Kernel registration: name `label_smooth`, backend `GPU`, layout
// `ALL_LAYOUT`, implementation `phi::LabelSmoothKernel`, listed for the
// `float` and `double` data types. The trailing braces are intentionally
// empty.
PD_REGISTER_KERNEL(
    label_smooth, GPU, ALL_LAYOUT, phi::LabelSmoothKernel, float, double) {}