// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/dist_kernel.h"

#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/p_norm_kernel.h"

namespace phi {

#define FULL_MASK 0xffffffff

// p == 0: each pair contributes 1 if x != y, so the reduction counts
// non-zero differences.
template <typename T>
struct ZeroOrderFunctor {
 public:
  __device__ T operator()(const T& x, const T& y) const {
    return static_cast<T>((x - y) != 0);
  }
};

// General finite order: each pair contributes |x - y|^p.
template <typename T>
struct OtherOrderFunctor {
  explicit OtherOrderFunctor(const T& p_order) : p_order_(p_order) {}
  __device__ T operator()(const T& x, const T& y) const {
    return static_cast<T>(pow(abs(x - y), p_order_));
  }

 private:
  T p_order_;
};

// Element-wise x^p; used to take the final 1/p root of the reduced sum.
template <typename T>
struct PowFunctor {
  explicit PowFunctor(const T& p_order) : p_order_(p_order) {}
  HOSTDEVICE inline T operator()(const T x) const {
    return static_cast<T>(pow(x, p_order_));
  }
  T p_order_;
};

// Grid-stride pass: each block reduces func(x[i], y[i]) to one partial sum,
// written to out[blockIdx.x].
template <typename T, typename Functor>
__global__ void ReduceSumWithSubtract(
    const T* x, const T* y, T* out, int64_t N, Functor func) {
  T sum_val = 0;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    sum_val += func(x[i], y[i]);
  }

  __syncthreads();
  sum_val = phi::funcs::BlockReduceSum<T>(sum_val, FULL_MASK);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = sum_val;
  }
}

// Each block reduces |x[i] - y[i]| to one partial maximum.
template <typename T>
__global__ void ReduceMaxWithSubtract(const T* x,
                                      const T* y,
                                      T* out,
                                      int64_t N) {
  T max_val = -1e10f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    max_val = max(max_val, abs(x[i] - y[i]));
  }

  __syncthreads();
  max_val = phi::funcs::BlockReduceMax<T>(max_val, FULL_MASK);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = max_val;
  }
}

// Each block reduces |x[i] - y[i]| to one partial minimum.
template <typename T>
__global__ void ReduceMinWithSubtract(const T* x,
                                      const T* y,
                                      T* out,
                                      int64_t N) {
  T min_val = 1e10f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    min_val = min(min_val, abs(x[i] - y[i]));
  }

  __syncthreads();
  min_val = phi::funcs::BlockReduceMin<T>(min_val, FULL_MASK);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = min_val;
  }
}

template <typename T, typename Context>
void DistKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const DenseTensor& y,
                float p,
                DenseTensor* out) {
  DenseTensor intermediate;
  const T* x_ptr = x.data<T>();
  const T* y_ptr = y.data<T>();
  T* o_ptr = dev_ctx.template Alloc<T>(out);
  auto stream = dev_ctx.stream();

  auto xdim = x.dims();
  if (xdim == y.dims()) {  // same shape
    auto n = x.numel();
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n);
    // One partial result per block; a second reduction collapses them.
    intermediate.Resize(phi::make_ddim({config.block_per_grid.x}));
    T* i_ptr = dev_ctx.template Alloc<T>(&intermediate);

    std::vector<int64_t> axis_dims = {static_cast<int64_t>(-1)};
    std::vector<int> reduce_axis =
        funcs::details::GetReduceDim(axis_dims, xdim.size(), true);

    if (p == 0) {
      // p == 0: count of non-zero differences.
      ReduceSumWithSubtract<T>
          <<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
              x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor<T>());
      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
          dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

    } else if (p == INFINITY) {
      // p == +inf: maximum absolute difference.
      ReduceMaxWithSubtract<T>
          <<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
              x_ptr, y_ptr, i_ptr, n);
      phi::funcs::ReduceKernel<T, T, kps::MaxFunctor, kps::IdentityFunctor<T>>(
          dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

    } else if (p == -INFINITY) {
      // p == -inf: minimum absolute difference.
      ReduceMinWithSubtract<T>
          <<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
              x_ptr, y_ptr, i_ptr, n);
      phi::funcs::ReduceKernel<T, T, kps::MinFunctor, kps::IdentityFunctor<T>>(
          dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

    } else {
      // General p: (sum |x - y|^p)^(1/p).
      T p_order = static_cast<T>(p);
      ReduceSumWithSubtract<T>
          <<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
              x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T>(p_order));
      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
          dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

      const DenseTensor* tmp_norm = out;
      std::vector<const DenseTensor*> ins = {tmp_norm};
      std::vector<DenseTensor*> outs = {out};
      T p_order_ = static_cast<T>(1. / p_order);
      phi::funcs::ElementwiseKernel<T>(
          dev_ctx, ins, &outs, PowFunctor<T>(p_order_));
    }

  } else {
    // Shapes differ: broadcast the subtraction and fall back to PNormKernel.
    auto t = Subtract<T, Context>(dev_ctx, x, y);
    PNormKernel<T, Context>(dev_ctx, t, p, -1, 1e-12, false, true, out);
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float, double) {}
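
// ---------------------------------------------------------------------------
// Illustrative sketch (excluded from compilation, not part of the registered
// kernel above): a minimal standalone CUDA program showing the same two-stage
// pattern DistKernel uses for the general p-norm case -- a grid-stride pass
// that accumulates |x[i] - y[i]|^p into one partial sum per block, a second
// reduction over those partials, and a final pow(sum, 1/p). The names
// (PartialPowSum) and launch configuration below are hypothetical and chosen
// only for the example; the real kernel relies on phi::funcs::BlockReduceSum
// and phi::funcs::ReduceKernel instead of the shared-memory tree used here.
// ---------------------------------------------------------------------------
#if 0  // example only
#include <cmath>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

__global__ void PartialPowSum(const float* x, const float* y, float* part,
                              int64_t n, float p) {
  __shared__ float cache[256];
  float local = 0.f;
  // Grid-stride loop: each thread folds several elements, as in
  // ReduceSumWithSubtract above.
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    local += powf(fabsf(x[i] - y[i]), p);
  }
  cache[threadIdx.x] = local;
  __syncthreads();
  // Shared-memory tree reduction stands in for phi::funcs::BlockReduceSum.
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) cache[threadIdx.x] += cache[threadIdx.x + stride];
    __syncthreads();
  }
  if (threadIdx.x == 0) part[blockIdx.x] = cache[0];
}

int main() {
  const int64_t n = 1 << 20;
  const float p = 3.f;
  std::vector<float> hx(n, 1.f), hy(n, 0.5f);
  const int threads = 256, blocks = 64;
  float *dx, *dy, *dpart;
  cudaMalloc(&dx, n * sizeof(float));
  cudaMalloc(&dy, n * sizeof(float));
  cudaMalloc(&dpart, blocks * sizeof(float));
  cudaMemcpy(dx, hx.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dy, hy.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  PartialPowSum<<<blocks, threads>>>(dx, dy, dpart, n, p);
  std::vector<float> hpart(blocks);
  cudaMemcpy(hpart.data(), dpart, blocks * sizeof(float),
             cudaMemcpyDeviceToHost);
  float sum = 0.f;
  for (float v : hpart) sum += v;          // second stage (ReduceKernel role)
  printf("dist_p = %f\n", powf(sum, 1.f / p));  // final root (PowFunctor role)
  cudaFree(dx);
  cudaFree(dy);
  cudaFree(dpart);
  return 0;
}
#endif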