Unverified commit 5d9e11a4, authored by huangxu96, committed by GitHub

Modified the sigmoid op to use the elementwise interface. (#39898)

* Modified sigmoid by the elementwise interface.

* Used TensorReduceImpl to replace the Sum function.

* Used TensorReduceImpl to calculate the norm variable.

* Removed unused code.
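In outline: the hand-written GPUSigmoidForward/GPUSigmoidBackward kernels become device functors that return two values per element (the loss and a "counted" flag), launched through phi::funcs::ElementwiseKernel; the block-reduce Sum kernel becomes a TensorReduceImpl call over the counts; and the Div kernel becomes an elementwise DivFunctor driven by a norm copied back to the host. The sketch below is a minimal standalone analogue of that pattern, not Paddle's API: `Array2`, `SigmoidFwd`, `Elementwise2Out`, and `DivBy` are invented names, and Thrust stands in for TensorReduceImpl.

```cuda
// Standalone sketch of the pattern this commit adopts (NOT Paddle's API).
#include <cmath>
#include <cstdio>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/reduce.h>
#include <thrust/transform.h>

struct Array2 {  // stand-in for phi::Array<float, 2>
  float v[2];
};

// Fused functor: one pass yields both the loss and the per-element count
// that later normalization needs, mirroring SigmoidFwdFunctor.
struct SigmoidFwd {
  float ignore_index;
  __host__ __device__ Array2 operator()(float x, float label) const {
    const float eps = 1e-5f;
    Array2 out;
    float diff = label - ignore_index;
    if (diff > -eps && diff < eps) {  // ignored position: no loss, no count
      out.v[0] = 0.f;
      out.v[1] = 0.f;
    } else {  // max(x, 0) - x * label + log(1 + exp(-|x|))
      float term1 = x > 0.f ? x : 0.f;
      out.v[0] = term1 - x * label + logf(1.f + expf(-fabsf(x)));
      out.v[1] = 1.f;
    }
    return out;
  }
};

// Minimal two-output elementwise launcher, the role that
// phi::funcs::ElementwiseKernel plays inside Paddle.
template <typename Functor>
__global__ void Elementwise2Out(const float *a, const float *b, float *out0,
                                float *out1, int n, Functor fn) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    Array2 r = fn(a[i], b[i]);
    out0[i] = r.v[0];
    out1[i] = r.v[1];
  }
}

struct DivBy {  // stand-in for DivFunctor
  float norm;
  __host__ __device__ float operator()(float v) const { return v / norm; }
};

int main() {
  const int n = 8;
  thrust::host_vector<float> x_h(n), label_h(n);
  for (int i = 0; i < n; ++i) {
    x_h[i] = 0.5f * (i - 4);
    label_h[i] = static_cast<float>(i % 2);
  }
  label_h[3] = -1.f;  // equals ignore_index below, so it is skipped

  thrust::device_vector<float> x = x_h, label = label_h, loss(n), counts(n);
  Elementwise2Out<<<1, 128>>>(thrust::raw_pointer_cast(x.data()),
                              thrust::raw_pointer_cast(label.data()),
                              thrust::raw_pointer_cast(loss.data()),
                              thrust::raw_pointer_cast(counts.data()), n,
                              SigmoidFwd{-1.f});

  // Reduce the counts on the device, clamp with the same 1e-5 floor the
  // commit uses, then divide every loss by the norm elementwise.
  float norm = thrust::reduce(counts.begin(), counts.end(), 0.f);
  norm = norm > 1e-5f ? norm : 1e-5f;
  thrust::transform(loss.begin(), loss.end(), loss.begin(), DivBy{norm});

  thrust::host_vector<float> loss_h = loss;
  for (int i = 0; i < n; ++i) printf("loss[%d] = %f\n", i, loss_h[i]);
  return 0;
}
```

Fusing the count into the same pass as the loss removes one full read of the inputs; the only remaining host-device synchronization is the single-scalar copy of the norm, which is exactly the trade the commit makes below.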
Parent 3e56e816
@@ -20,9 +20,11 @@ namespace cub = hipcub;
 #endif
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/math.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
 
 namespace paddle {
 namespace operators {
@@ -42,71 +44,86 @@ static inline int NumBlocks(const int N) {
 }
 
 template <typename T>
-__global__ void GPUSigmoidForward(const T *x_data, const T *label_data,
-                                  const int ignore_index, const int limit,
-                                  T *out_data, T *counts) {
-  CUDA_KERNEL_LOOP(i, limit) {
-    T x = x_data[i];
-    T label = label_data[i];
-    T eps = static_cast<T>(1e-5);
-    T diff = label - static_cast<T>(ignore_index);
+struct NonzeroFunctor {
+  HOSTDEVICE explicit inline NonzeroFunctor() {}
+  HOSTDEVICE inline T operator()(const T x) const {
+    return static_cast<T>(static_cast<double>(x) != 0);
+  }
+};
+
+template <typename T>
+struct SigmoidFwdFunctor {
+  T ignore_index_;
+  T eps = static_cast<T>(1e-5);
+
+  HOSTDEVICE inline SigmoidFwdFunctor(const T ignore_index)
+      : ignore_index_(ignore_index) {}
+
+  HOSTDEVICE inline phi::Array<T, 2> operator()(const T x, const T label) {
+    T counts;
+    T out_data;
+    T diff = label - static_cast<T>(ignore_index_);
     if ((diff > -eps) && (diff < eps)) {
-      out_data[i] = static_cast<T>(0.);
-      counts[i] = 0;
+      out_data = static_cast<T>(0.);
+      counts = 0;
     } else {
       T term1 = (x > 0) ? x : 0;
       T term2 = x * label;
       T term3 = real_log(static_cast<T>(1) + real_exp(static_cast<T>(-abs(x))));
-      out_data[i] = term1 - term2 + term3;
-      counts[i] = 1;
+      out_data = term1 - term2 + term3;
+      counts = 1;
     }
+    phi::Array<T, 2> outs;
+
+    outs[0] = out_data;
+    outs[1] = counts;
+    return outs;
   }
-}
-
-template <typename T, int BlockDim>
-__global__ void Sum(const T *counts, int num, const T eps, T *sum) {
-  typedef cub::BlockReduce<double, BlockDim> BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-  T in = 0;
-  for (int i = threadIdx.x; i < num; i += BlockDim) {
-    in += counts[i];
-  }
-  __syncthreads();
-  auto out =
-      BlockReduce(temp_storage).Reduce(static_cast<double>(in), cub::Sum());
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    T a = out > eps ? out : eps;
-    sum[0] = a;
-  }
-}
+};
 
 template <typename T>
-__global__ void Div(T *loss, const int num, const T *norm) {
-  CUDA_KERNEL_LOOP(i, num) { loss[i] /= norm[0]; }
-}
-
-template <typename T>
-__global__ void GPUSigmoidBackward(const T *x_data, const T *label_data,
-                                   const int ignore_index, const T *dout_data,
-                                   const int limit, T *dx_data, T *counts) {
-  CUDA_KERNEL_LOOP(i, limit) {
-    T x = x_data[i];
-    T label = label_data[i];
-    T dout = dout_data[i];
-    T eps = static_cast<T>(1e-5);
-    T diff = label - static_cast<T>(ignore_index);
+struct SigmoidBwdFunctor {
+  T ignore_index_;
+  T eps = static_cast<T>(1e-5);
+
+  HOSTDEVICE inline SigmoidBwdFunctor(const T ignore_index)
+      : ignore_index_(ignore_index) {}
+
+  HOSTDEVICE inline phi::Array<T, 2> operator()(const T x, const T label,
+                                                const T dout) {
+    T counts;
+    T dx_data;
+    T diff = label - static_cast<T>(ignore_index_);
     if ((diff > -eps) && (diff < eps)) {
-      dx_data[i] = static_cast<T>(0.);
-      counts[i] = 0;
+      dx_data = static_cast<T>(0.);
+      counts = 0;
     } else {
       T simoid_x = static_cast<T>(1) / (static_cast<T>(1) + real_exp(-x));
       T diff = simoid_x - label;
-      dx_data[i] = dout * diff;
-      counts[i] = 1;
+      dx_data = dout * diff;
+      counts = 1;
    }
+    phi::Array<T, 2> outs;
+    outs[0] = dx_data;
+    outs[1] = counts;
+    return outs;
   }
-}
+};
+
+template <typename T>
+struct DivFunctor {
+  const T norm_;
+  HOSTDEVICE inline DivFunctor(const T norm) : norm_(norm) {}
+
+  HOSTDEVICE inline T operator()(T loss) {
+    loss /= norm_;
+    return loss;
+  }
+};
 
 // Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
 template <typename DeviceContext, typename T>
@@ -123,20 +140,48 @@ class GPUSigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
     bool normalize = context.Attr<bool>("normalize");
 
     // Temporary memory
-    auto cnt_ptr = memory::Alloc(dev_ctx, Labels->numel() * sizeof(T));
-    T *counts = reinterpret_cast<T *>(cnt_ptr->ptr());
+    Tensor *counts_tensor = new Tensor();
+    counts_tensor->mutable_data<T>(context.GetPlace(),
+                                   Labels->numel() * sizeof(T));
+    counts_tensor->Resize(Out->dims());
     int limit = Out->numel();
     int blocks = NumBlocks(limit);
     int threads = kNumCUDAThreads;
-    GPUSigmoidForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        X->data<T>(), Labels->data<T>(), ignore_index, limit, out_data, counts);
+    std::vector<const framework::Tensor *> ins = {X, Labels};
+    std::vector<framework::Tensor *> outs = {Out, counts_tensor};
+    auto functor = SigmoidFwdFunctor<T>(ignore_index);
+    constexpr int Size = 2;
+    phi::funcs::ElementwiseKernel<T, decltype(functor), Size>(dev_ctx, ins,
+                                                              &outs, functor);
     if (normalize) {
-      auto norm_ptr = memory::Alloc(dev_ctx, sizeof(T));
-      T *norm = reinterpret_cast<T *>(norm_ptr->ptr());
-      Sum<T, kNumCUDAThreads><<<1, kNumCUDAThreads, 0, dev_ctx.stream()>>>(
-          counts, limit, static_cast<T>(1e-5), norm);
-      Div<T><<<blocks, threads, 0, dev_ctx.stream()>>>(out_data, limit, norm);
+      T *counts = counts_tensor->mutable_data<T>(context.GetPlace());
+      Tensor *norm_tensor = new Tensor();
+      norm_tensor->mutable_data<T>(context.GetPlace(), sizeof(T));
+      auto dims = phi::vectorize(counts_tensor->dims());
+      std::vector<int> reduce_dim = {};
+      for (int i = 0; i < dims.size(); i++) {
+        reduce_dim.push_back(i);
+      }
+
+      TensorReduceImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
+          context.cuda_device_context(), *counts_tensor, norm_tensor,
+          NonzeroFunctor<T>(), reduce_dim, dev_ctx.stream());
+      T *norm = norm_tensor->mutable_data<T>(context.GetPlace());
+      auto norm_cpu_mem = memory::Alloc(platform::CPUPlace(), sizeof(T));
+      T *norm_cpu_ptr = reinterpret_cast<T *>(norm_cpu_mem->ptr());
+      memory::Copy(platform::CPUPlace(), norm_cpu_ptr, dev_ctx.GetPlace(), norm,
+                   sizeof(T), dev_ctx.stream());
+      auto eps = static_cast<T>(1e-5);
+      *norm_cpu_ptr = *norm_cpu_ptr > eps ? *norm_cpu_ptr : eps;
+
+      std::vector<const framework::Tensor *> div_ins = {Out};
+      std::vector<framework::Tensor *> div_outs = {Out};
+      auto div_functor = DivFunctor<T>(*norm_cpu_ptr);
+      phi::funcs::ElementwiseKernel<T>(dev_ctx, div_ins, &div_outs,
+                                       div_functor);
+
+      delete norm_tensor;
+      delete counts_tensor;
     }
   }
 };
@@ -157,22 +202,48 @@ class GPUSigmoidCrossEntropyWithLogitsGradKernel
     auto &dev_ctx = context.cuda_device_context();
 
     // Temporary memory
-    auto cnt_ptr = memory::Alloc(dev_ctx, X->numel() * sizeof(T));
-    T *counts = reinterpret_cast<T *>(cnt_ptr->ptr());
+    Tensor *counts_tensor = new Tensor();
+    counts_tensor->mutable_data<T>(context.GetPlace(),
+                                   Labels->numel() * sizeof(T));
+    counts_tensor->Resize(dX->dims());
     int limit = dX->numel();
     int blocks = NumBlocks(limit);
     int threads = kNumCUDAThreads;
-    GPUSigmoidBackward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        X->data<T>(), Labels->data<T>(), ignore_index, dOut->data<T>(), limit,
-        dx_data, counts);
+    std::vector<const framework::Tensor *> ins = {X, Labels, dOut};
+    std::vector<framework::Tensor *> outs = {dX, counts_tensor};
+    auto functor = SigmoidBwdFunctor<T>(ignore_index);
+    constexpr int Size = 2;
+    phi::funcs::ElementwiseKernel<T, decltype(functor), Size>(dev_ctx, ins,
+                                                              &outs, functor);
     bool normalize = context.Attr<bool>("normalize");
     if (normalize) {
-      auto norm_ptr = memory::Alloc(dev_ctx, sizeof(T));
-      T *norm = reinterpret_cast<T *>(norm_ptr->ptr());
-      Sum<T, kNumCUDAThreads><<<1, kNumCUDAThreads, 0, dev_ctx.stream()>>>(
-          counts, limit, static_cast<T>(1e-5), norm);
-      Div<T><<<blocks, threads, 0, dev_ctx.stream()>>>(dx_data, limit, norm);
+      T *counts = counts_tensor->mutable_data<T>(context.GetPlace());
+      Tensor *norm_tensor = new Tensor();
+      norm_tensor->mutable_data<T>(context.GetPlace(), sizeof(T));
+      auto dims = phi::vectorize(counts_tensor->dims());
+      std::vector<int> reduce_dim = {};
+      for (int i = 0; i < dims.size(); i++) {
+        reduce_dim.push_back(i);
+      }
+
+      TensorReduceImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
+          context.cuda_device_context(), *counts_tensor, norm_tensor,
+          NonzeroFunctor<T>(), reduce_dim, dev_ctx.stream());
+      T *norm = norm_tensor->mutable_data<T>(context.GetPlace());
+      auto norm_cpu_mem = memory::Alloc(platform::CPUPlace(), sizeof(T));
+      T *norm_cpu_ptr = reinterpret_cast<T *>(norm_cpu_mem->ptr());
+      memory::Copy(platform::CPUPlace(), norm_cpu_ptr, dev_ctx.GetPlace(), norm,
+                   sizeof(T), dev_ctx.stream());
+      auto eps = static_cast<T>(1e-5);
+      *norm_cpu_ptr = *norm_cpu_ptr > eps ? *norm_cpu_ptr : eps;
+
+      std::vector<const framework::Tensor *> div_ins = {dX};
+      std::vector<framework::Tensor *> div_outs = {dX};
+      auto div_functor = DivFunctor<T>(*norm_cpu_ptr);
+      phi::funcs::ElementwiseKernel<T>(dev_ctx, div_ins, &div_outs,
+                                       div_functor);
+
+      delete norm_tensor;
    }
  }
};
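For reference, the stable form in the `// Out = ...` comment and the `simoid_x - label` factor in SigmoidBwdFunctor both follow from the logistic loss. A short derivation (standard algebra, not part of the commit), with sigma(x) = 1/(1 + e^{-x}) and label z:

```latex
\begin{aligned}
\ell(x, z) &= -z \log \sigma(x) - (1 - z) \log\bigl(1 - \sigma(x)\bigr) \\
           &= x - x z + \log\bigl(1 + e^{-x}\bigr) \\
           &= \max(x, 0) - x z + \log\bigl(1 + e^{-\lvert x \rvert}\bigr)
\end{aligned}
% The last step uses x + \log(1 + e^{-x}) = \log(1 + e^{x}) for x < 0, so the
% exponent never overflows. Differentiating gives
% \partial\ell/\partial x = \sigma(x) - z, which is exactly the factor the
% backward functor multiplies by dout.
```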