diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu index 40476d5e11f6a3b0cad21038a3f342d824f3575c..18402d908c4ad8d67bf7fc980a9e5c8917beb142 100644 --- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu +++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu @@ -20,9 +20,11 @@ namespace cub = hipcub; #endif #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/operators/math.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" namespace paddle { namespace operators { @@ -42,71 +44,86 @@ static inline int NumBlocks(const int N) { } template -__global__ void GPUSigmoidForward(const T *x_data, const T *label_data, - const int ignore_index, const int limit, - T *out_data, T *counts) { - CUDA_KERNEL_LOOP(i, limit) { - T x = x_data[i]; - T label = label_data[i]; - T eps = static_cast(1e-5); - T diff = label - static_cast(ignore_index); +struct NonzeroFunctor { + HOSTDEVICE explicit inline NonzeroFunctor() {} + HOSTDEVICE inline T operator()(const T x) const { + return static_cast(static_cast(x) != 0); + } +}; + +template +struct SigmoidFwdFunctor { + T ignore_index_; + T eps = static_cast(1e-5); + + HOSTDEVICE inline SigmoidFwdFunctor(const T ignore_index) + : ignore_index_(ignore_index) {} + + HOSTDEVICE inline phi::Array operator()(const T x, const T label) { + T counts; + T out_data; + + T diff = label - static_cast(ignore_index_); if ((diff > -eps) && (diff < eps)) { - out_data[i] = static_cast(0.); - counts[i] = 0; + out_data = static_cast(0.); + counts = 0; } else { T term1 = (x > 0) ? x : 0; T term2 = x * label; T term3 = real_log(static_cast(1) + real_exp(static_cast(-abs(x)))); - out_data[i] = term1 - term2 + term3; - counts[i] = 1; + + out_data = term1 - term2 + term3; + counts = 1; } - } -} + phi::Array outs; -template -__global__ void Sum(const T *counts, int num, const T eps, T *sum) { - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - T in = 0; - for (int i = threadIdx.x; i < num; i += BlockDim) { - in += counts[i]; + outs[0] = out_data; + outs[1] = counts; + return outs; } - __syncthreads(); - auto out = - BlockReduce(temp_storage).Reduce(static_cast(in), cub::Sum()); - __syncthreads(); - if (threadIdx.x == 0) { - T a = out > eps ? out : eps; - sum[0] = a; - } -} +}; template -__global__ void Div(T *loss, const int num, const T *norm) { - CUDA_KERNEL_LOOP(i, num) { loss[i] /= norm[0]; } -} +struct SigmoidBwdFunctor { + T ignore_index_; + T eps = static_cast(1e-5); -template -__global__ void GPUSigmoidBackward(const T *x_data, const T *label_data, - const int ignore_index, const T *dout_data, - const int limit, T *dx_data, T *counts) { - CUDA_KERNEL_LOOP(i, limit) { - T x = x_data[i]; - T label = label_data[i]; - T dout = dout_data[i]; - T eps = static_cast(1e-5); - T diff = label - static_cast(ignore_index); + HOSTDEVICE inline SigmoidBwdFunctor(const T ignore_index) + : ignore_index_(ignore_index) {} + + HOSTDEVICE inline phi::Array operator()(const T x, const T label, + const T dout) { + T counts; + T dx_data; + + T diff = label - static_cast(ignore_index_); if ((diff > -eps) && (diff < eps)) { - dx_data[i] = static_cast(0.); - counts[i] = 0; + dx_data = static_cast(0.); + counts = 0; } else { T simoid_x = static_cast(1) / (static_cast(1) + real_exp(-x)); T diff = simoid_x - label; - dx_data[i] = dout * diff; - counts[i] = 1; + dx_data = dout * diff; + counts = 1; } + phi::Array outs; + + outs[0] = dx_data; + outs[1] = counts; + return outs; } -} +}; + +template +struct DivFunctor { + const T norm_; + HOSTDEVICE inline DivFunctor(const T norm) : norm_(norm) {} + + HOSTDEVICE inline T operator()(T loss) { + loss /= norm_; + return loss; + } +}; // Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X))) template @@ -123,20 +140,48 @@ class GPUSigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel { bool normalize = context.Attr("normalize"); // Temporary memory - auto cnt_ptr = memory::Alloc(dev_ctx, Labels->numel() * sizeof(T)); - T *counts = reinterpret_cast(cnt_ptr->ptr()); - + Tensor *counts_tensor = new Tensor(); + counts_tensor->mutable_data(context.GetPlace(), + Labels->numel() * sizeof(T)); + counts_tensor->Resize(Out->dims()); int limit = Out->numel(); int blocks = NumBlocks(limit); int threads = kNumCUDAThreads; - GPUSigmoidForward<<>>( - X->data(), Labels->data(), ignore_index, limit, out_data, counts); + std::vector ins = {X, Labels}; + std::vector outs = {Out, counts_tensor}; + auto functor = SigmoidFwdFunctor(ignore_index); + constexpr int Size = 2; + phi::funcs::ElementwiseKernel(dev_ctx, ins, + &outs, functor); if (normalize) { - auto norm_ptr = memory::Alloc(dev_ctx, sizeof(T)); - T *norm = reinterpret_cast(norm_ptr->ptr()); - Sum<<<1, kNumCUDAThreads, 0, dev_ctx.stream()>>>( - counts, limit, static_cast(1e-5), norm); - Div<<>>(out_data, limit, norm); + T *counts = counts_tensor->mutable_data(context.GetPlace()); + Tensor *norm_tensor = new Tensor(); + norm_tensor->mutable_data(context.GetPlace(), sizeof(T)); + auto dims = phi::vectorize(counts_tensor->dims()); + std::vector reduce_dim = {}; + for (int i = 0; i < dims.size(); i++) { + reduce_dim.push_back(i); + } + + TensorReduceImpl>( + context.cuda_device_context(), *counts_tensor, norm_tensor, + NonzeroFunctor(), reduce_dim, dev_ctx.stream()); + T *norm = norm_tensor->mutable_data(context.GetPlace()); + auto norm_cpu_mem = memory::Alloc(platform::CPUPlace(), sizeof(T)); + T *norm_cpu_ptr = reinterpret_cast(norm_cpu_mem->ptr()); + memory::Copy(platform::CPUPlace(), norm_cpu_ptr, dev_ctx.GetPlace(), norm, + sizeof(T), dev_ctx.stream()); + auto eps = static_cast(1e-5); + *norm_cpu_ptr = *norm_cpu_ptr > eps ? *norm_cpu_ptr : eps; + + std::vector div_ins = {Out}; + std::vector div_outs = {Out}; + auto div_functor = DivFunctor(*norm_cpu_ptr); + phi::funcs::ElementwiseKernel(dev_ctx, div_ins, &div_outs, + div_functor); + + delete norm_tensor; + delete counts_tensor; } } }; @@ -157,22 +202,48 @@ class GPUSigmoidCrossEntropyWithLogitsGradKernel auto &dev_ctx = context.cuda_device_context(); // Temporary memory - auto cnt_ptr = memory::Alloc(dev_ctx, X->numel() * sizeof(T)); - T *counts = reinterpret_cast(cnt_ptr->ptr()); + Tensor *counts_tensor = new Tensor(); + counts_tensor->mutable_data(context.GetPlace(), + Labels->numel() * sizeof(T)); + counts_tensor->Resize(dX->dims()); int limit = dX->numel(); int blocks = NumBlocks(limit); int threads = kNumCUDAThreads; - GPUSigmoidBackward<<>>( - X->data(), Labels->data(), ignore_index, dOut->data(), limit, - dx_data, counts); + std::vector ins = {X, Labels, dOut}; + std::vector outs = {dX, counts_tensor}; + auto functor = SigmoidBwdFunctor(ignore_index); + constexpr int Size = 2; + phi::funcs::ElementwiseKernel(dev_ctx, ins, + &outs, functor); bool normalize = context.Attr("normalize"); if (normalize) { - auto norm_ptr = memory::Alloc(dev_ctx, sizeof(T)); - T *norm = reinterpret_cast(norm_ptr->ptr()); - Sum<<<1, kNumCUDAThreads, 0, dev_ctx.stream()>>>( - counts, limit, static_cast(1e-5), norm); - Div<<>>(dx_data, limit, norm); + T *counts = counts_tensor->mutable_data(context.GetPlace()); + Tensor *norm_tensor = new Tensor(); + norm_tensor->mutable_data(context.GetPlace(), sizeof(T)); + auto dims = phi::vectorize(counts_tensor->dims()); + std::vector reduce_dim = {}; + for (int i = 0; i < dims.size(); i++) { + reduce_dim.push_back(i); + } + + TensorReduceImpl>( + context.cuda_device_context(), *counts_tensor, norm_tensor, + NonzeroFunctor(), reduce_dim, dev_ctx.stream()); + T *norm = norm_tensor->mutable_data(context.GetPlace()); + auto norm_cpu_mem = memory::Alloc(platform::CPUPlace(), sizeof(T)); + T *norm_cpu_ptr = reinterpret_cast(norm_cpu_mem->ptr()); + memory::Copy(platform::CPUPlace(), norm_cpu_ptr, dev_ctx.GetPlace(), norm, + sizeof(T), dev_ctx.stream()); + auto eps = static_cast(1e-5); + *norm_cpu_ptr = *norm_cpu_ptr > eps ? *norm_cpu_ptr : eps; + + std::vector div_ins = {dX}; + std::vector div_outs = {dX}; + auto div_functor = DivFunctor(*norm_cpu_ptr); + phi::funcs::ElementwiseKernel(dev_ctx, div_ins, &div_outs, + div_functor); + delete norm_tensor; } } };