未验证 提交 5970871a 编写于 作者: Z Zhaolong Xing 提交者: GitHub

add eltwise clip cuda impl. (#25689)

test=develop
上级 36027490
...@@ -25,17 +25,23 @@ namespace operators { ...@@ -25,17 +25,23 @@ namespace operators {
using framework::Tensor; using framework::Tensor;
using platform::Transform; using platform::Transform;
#ifdef __NVCC__
template <typename T, typename UnaryOperation>
__global__ void ClipCudaKernel(const T* input, T* out, int num,
UnaryOperation op) {
int idx = threadIdx.x + blockDim.x * blockIdx.x;
if (idx < num) {
out[idx] = op(input[idx]);
}
}
#endif
template <typename T> template <typename T>
class ClipFunctor { class ClipFunctor {
public: public:
explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {} explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T& x) const { HOSTDEVICE T operator()(const T& x) const {
if (x < min_) return x < min_ ? min_ : x > max_ ? max_ : x;
return min_;
else if (x > max_)
return max_;
else
return x;
} }
private: private:
...@@ -97,9 +103,20 @@ class ClipKernel : public framework::OpKernel<T> { ...@@ -97,9 +103,20 @@ class ClipKernel : public framework::OpKernel<T> {
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
int64_t numel = x->numel(); int64_t numel = x->numel();
Transform<DeviceContext> trans; if (platform::is_gpu_place(context.GetPlace())) {
trans(context.template device_context<DeviceContext>(), x_data, #ifdef __NVCC__
x_data + numel, out_data, ClipFunctor<T>(min, max)); int threads = 256;
int blocks = (numel + threads - 1) / threads;
ClipCudaKernel<T, ClipFunctor<T>><<<
blocks, threads, 0,
context.template device_context<platform::CUDADeviceContext>()
.stream()>>>(x_data, out_data, numel, ClipFunctor<T>(min, max));
#endif
} else {
Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), x_data,
x_data + numel, out_data, ClipFunctor<T>(min, max));
}
} else if (x_var->IsType<framework::SelectedRows>()) { } else if (x_var->IsType<framework::SelectedRows>()) {
auto* x = context.Input<framework::SelectedRows>("X"); auto* x = context.Input<framework::SelectedRows>("X");
auto* out = context.Output<framework::SelectedRows>("Out"); auto* out = context.Output<framework::SelectedRows>("Out");
......
...@@ -197,6 +197,40 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x, ...@@ -197,6 +197,40 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,
} }
#ifdef __NVCC__ #ifdef __NVCC__
template <typename Functor, typename T, typename OutType>
__global__ void ElementwiseKernel(const T *x, const T *y, OutType *out, int pre,
int n, int post, int total, Functor func) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int idx = tid / post % n;
if (tid < total) {
out[tid] = func(x[tid], y[idx]);
}
}
template <typename Functor, typename T, typename OutType>
void ComputeElementwiseCUDA(const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z,
int pre, int n, int post,
const platform::CUDADeviceContext &ctx,
Functor func, const bool is_xsize_larger = true) {
const T *x_data = x->data<T>();
const T *y_data = y->data<T>();
OutType *out_data = z->mutable_data<OutType>(ctx.GetPlace());
int numel = pre * n * post;
int threads = 256;
int blocks = (numel + threads - 1) / threads;
if (is_xsize_larger) {
ElementwiseKernel<Functor, T,
OutType><<<blocks, threads, 0, ctx.stream()>>>(
x_data, y_data, out_data, pre, n, post, numel, func);
} else {
ElementwiseKernel<Functor, T,
OutType><<<blocks, threads, 0, ctx.stream()>>>(
y_data, x_data, out_data, pre, n, post, numel, func);
}
}
template <typename Functor, typename T, typename OutType = T> template <typename Functor, typename T, typename OutType = T>
__global__ void CommonForwardBroadcastCUDAKernel( __global__ void CommonForwardBroadcastCUDAKernel(
const int *x_strides_array, const int *y_strides_array, const int *x_strides_array, const int *y_strides_array,
...@@ -1908,6 +1942,16 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx, ...@@ -1908,6 +1942,16 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
ctx, x, y, z, x_dims, y_dims, func, axis, is_xsize_larger); ctx, x, y, z, x_dims, y_dims, func, axis, is_xsize_larger);
return; return;
} }
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
ComputeElementwiseCUDA<Functor, T, OutType>(
x, y, z, pre, n, post,
ctx.template device_context<platform::CUDADeviceContext>(), func,
is_xsize_larger);
#endif
return;
}
if (post == 1) { if (post == 1) {
functor.RunRowWise(n, pre); functor.RunRowWise(n, pre);
return; return;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册