squared_l2_norm.h 3.1 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/memory/buffer.h"
18 19
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
S
sneaxiy 已提交
20 21

#if defined(__NVCC__) || defined(__HIPCC__)
22
#include "paddle/phi/kernels/primitive/functor_primitives.h"
S
sneaxiy 已提交
23 24 25 26 27 28 29 30
#ifdef __NVCC__
#include "cub/cub.cuh"
#else
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#endif

31 32
namespace phi {
namespace funcs {
S
sneaxiy 已提交
33 34

template <typename T1, typename T2 = T1>
L
Leo Chen 已提交
35
void SquaredL2Norm(const phi::CPUContext& ctx,
36 37 38
                   const T1* x,
                   T2* y,
                   size_t numel,
39
                   paddle::memory::Buffer* buffer = nullptr) {
S
sneaxiy 已提交
40
  if (std::is_same<T1, T2>::value) {
41 42 43
    using EigenT = typename phi::EigenTensor<T1, 1>::Type;
    using ConstEigenT = typename phi::EigenTensor<T1, 1>::ConstType;
    using EigenDim = typename phi::EigenDim<1>::Type;
S
sneaxiy 已提交
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
    ConstEigenT input(x, EigenDim(numel));
    EigenT output(reinterpret_cast<T1*>(y), EigenDim(1));
    output.device(*ctx.eigen_device()) = input.square().sum();
  } else {
    T2 ret = static_cast<T2>(0);
    for (size_t i = 0; i < numel; ++i) {
      auto tmp = static_cast<T2>(x[i]);
      ret += tmp * tmp;
    }
    *y = ret;
  }
}

#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T1, typename T2 = T1>
59
void SquaredL2Norm(const phi::GPUContext& ctx,
60 61 62
                   const T1* x,
                   T2* y,
                   size_t numel,
63
                   paddle::memory::Buffer* buffer = nullptr) {
S
sneaxiy 已提交
64
  if (UNLIKELY(buffer == nullptr)) {
65
    paddle::memory::Buffer tmp_buffer(ctx.GetPlace());
S
sneaxiy 已提交
66 67 68
    return SquaredL2Norm(ctx, x, y, numel, &tmp_buffer);
  }

69
  using FunctorT = phi::kps::SquareFunctor<T1, T2>;
S
sneaxiy 已提交
70 71 72 73 74 75 76 77 78
  cub::TransformInputIterator<T2, FunctorT, const T1*> iter(x, FunctorT());
  size_t temp_storage_bytes = 0;
  void* d_temp_storage = nullptr;
  auto stream = ctx.stream();
#pragma unroll 2
  for (size_t i = 0; i < 2; ++i) {
    if (temp_storage_bytes > 0) {
      d_temp_storage = buffer->Alloc<void>(temp_storage_bytes);
    }
79 80 81 82 83 84 85
    PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Reduce(d_temp_storage,
                                                         temp_storage_bytes,
                                                         iter,
                                                         y,
                                                         numel,
                                                         cub::Sum(),
                                                         static_cast<T2>(0)));
S
sneaxiy 已提交
86 87 88 89
  }
}
#endif

90 91
}  // namespace funcs
}  // namespace phi