/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/p_norm_op.h"

namespace paddle {
namespace operators {

// Returns +1, 0 or -1 according to the sign of `val`.
template <typename T>
__device__ __forceinline__ int sgn(T val) {
  return (T(0) < val) - (val < T(0));
}

// float/double overloads so the kernels below stay generic in T while
// calling the precision-matched device math function for each type.
__device__ __forceinline__ float inline_abs(float x) { return fabsf(x); }
__device__ __forceinline__ double inline_abs(double x) { return fabs(x); }

__device__ __forceinline__ int inline_sign(float x) { return sgn<float>(x); }
__device__ __forceinline__ int inline_sign(double x) { return sgn<double>(x); }

__device__ __forceinline__ float inline_pow(float base, float exponent) {
  return powf(base, exponent);
}
__device__ __forceinline__ double inline_pow(double base, double exponent) {
  return pow(base, exponent);
}

// Forward p-norm along one axis:
//   out_norm[i] = (sum_j |x(i, j)|^porder)^(1 / porder)
//
// Layout: x is addressed as a row-major [pre, axis_n, post] array; each of
// the pre * post output positions i reduces over the middle axis_n dim.
// Launch: 1-D grid of any size (blocks stride over the pre * post outputs);
// blockDim.x must equal BlockDim, the width cub::BlockReduce is built for.
// Shared memory: static only (cub temp storage).
template <typename T, int BlockDim>
__global__ void Pnorm(const T* x, const int pre,
                      const int axis_n,  // dim in axis
                      const int post, float porder, T* out_norm) {
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  const int num = pre * post;
  const auto porder_t = static_cast<T>(porder);
  const auto porder_inv = static_cast<T>(1.0 / porder);

  for (int i = blockIdx.x; i < num; i += gridDim.x) {
    const int base = (i / post) * post * axis_n + (i % post);

    // Per-thread partial sum of |x|^p over a block-strided slice of the axis.
    T sum = 0.0;
    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      const T x_ij = x[base + j * post];
      sum += inline_pow(inline_abs(x_ij), porder_t);
    }
    // The block-wide sum is only valid in thread 0.
    const T reduce_result = BlockReduce(temp_storage).Sum(sum);
    if (threadIdx.x == 0) {
      out_norm[i] = inline_pow(reduce_result, porder_inv);
    }
    // temp_storage is reused by the next iteration's reduction; CUB requires
    // a block barrier before its shared storage may be reused.
    __syncthreads();
  }
}

// Op kernel for p_norm: reads attributes "porder" and "axis" (negative axis
// counts from the back), splits X's dims into pre / n / post around the
// axis via GetDims, and launches Pnorm on the context's stream.
template <typename DeviceContext, typename T>
class PnormCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in_x = ctx.Input<framework::Tensor>("X");
    auto* out_norm = ctx.Output<framework::Tensor>("Out");
    const T* x = in_x->data<T>();
    T* norm = out_norm->mutable_data<T>(ctx.GetPlace());

    auto xdim = in_x->dims();
    float porder = ctx.Attr<float>("porder");
    int axis = ctx.Attr<int>("axis");
    if (axis < 0) axis = xdim.size() + axis;
    int pre, n, post;
    GetDims(xdim, axis, &pre, &n, &post);

    // No output positions: nothing to compute, and a 0-block grid would be
    // an invalid launch configuration.
    if (pre * post == 0) return;

    auto& dev_ctx = ctx.cuda_device_context();

    const int block = 512;
    const int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    const int grid = std::min(max_blocks, pre * post);
    Pnorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
                                                          porder, norm);
  }
};

// Backward p-norm along one axis:
//   x_grad(i, j) = |x(i, j)|^(p-1) / (x_norm[i]^(p-1) + eps)
//                  * y_grad[i] * sign(x(i, j))
// eps is added to the denominator as a guard against division by zero.
// Layout and launch requirements match Pnorm.
template <typename T, int BlockDim>
__global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
                              const float porder, const int pre,
                              const int axis_n, const int post, const T eps,
                              T* x_grad) {
  // dx = (x/pnorm_broadcast).pow(p-1) * norm_dy.broadcast * sign(x)
  const int num = pre * post;
  const auto porder_grad = static_cast<T>(porder - 1.0f);
  // Per-output values loaded once by thread 0 and broadcast to the block.
  __shared__ T pnorm_i;
  __shared__ T yout_i;

  for (int i = blockIdx.x; i < num; i += gridDim.x) {
    const int base = (i / post) * post * axis_n + (i % post);

    if (threadIdx.x == 0) {
      pnorm_i = x_norm[i];
      yout_i = y_grad[i];
    }
    __syncthreads();

    // Row-invariant denominator, computed once per thread.
    const T denom = inline_pow(pnorm_i, porder_grad) + eps;
    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      const int index = base + j * post;
      const T x_ij = inline_abs(x[index]);
      x_grad[index] = inline_pow(x_ij, porder_grad) / denom * yout_i *
                      inline_sign(x[index]);
    }
    // Without this barrier thread 0 can start the next iteration and
    // overwrite pnorm_i / yout_i while other threads still read this row's
    // values. Every thread runs the same iterations, so it is uniform.
    __syncthreads();
  }
}

// Op kernel for p_norm_grad: reads X, Out and Out@GRAD plus attributes
// "porder", "epsilon" and "axis", then launches PnormGradient to fill
// X@GRAD on the context's stream.
template <typename DeviceContext, typename T>
class PnormGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in_x = ctx.Input<framework::Tensor>("X");
    auto* in_norm = ctx.Input<framework::Tensor>("Out");
    auto* in_norm_dy =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* out_dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    T* dx = out_dx->mutable_data<T>(ctx.GetPlace());
    const T* x = in_x->data<T>();
    const T* x_norm = in_norm->data<T>();
    const T* norm_dy = in_norm_dy->data<T>();

    auto xdim = in_x->dims();
    float porder = ctx.Attr<float>("porder");
    T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
    int axis = ctx.Attr<int>("axis");
    if (axis < 0) axis = xdim.size() + axis;
    int pre, n, post;
    GetDims(xdim, axis, &pre, &n, &post);

    // No output positions: dx is empty and a 0-block grid would be an
    // invalid launch configuration.
    if (pre * post == 0) return;

    auto& dev_ctx = ctx.cuda_device_context();

    const int block = 512;
    const int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    const int grid = std::min(max_blocks, pre * post);
    PnormGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
        x, x_norm, norm_dy, porder, pre, n, post, eps, dx);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;

REGISTER_OP_CUDA_KERNEL(p_norm, ops::PnormCUDAKernel<CUDA, float>,
                        ops::PnormCUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(p_norm_grad, ops::PnormGradCUDAKernel<CUDA, float>,
                        ops::PnormGradCUDAKernel<CUDA, double>);