diff --git a/paddle/fluid/operators/abs_op.cu b/paddle/fluid/operators/abs_op.cu
index e373d628f6cbd6b5ee48edc984a68d2767ce0593..97409e6cb1b17b8fc109e30dc78720b8d573f042 100644
--- a/paddle/fluid/operators/abs_op.cu
+++ b/paddle/fluid/operators/abs_op.cu
@@ -13,44 +13,79 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/abs_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/complex128.h"
 #include "paddle/fluid/platform/complex64.h"
 #include "paddle/fluid/platform/float16.h"
 
+namespace paddle {
+namespace operators {
+
+template <typename T, typename Enable = void>
+struct CudaAbsFunctor;
+
+template <typename T>
+struct CudaAbsFunctor<T, math::Complex<T, math::Real<T>>> {
+  __device__ __forceinline__ math::Real<T> operator()(const T* args) const {
+    return abs(args[0]);
+  }
+};
+
+template <typename T>
+struct CudaAbsFunctor<T, math::NoComplex<T, math::Real<T>>> {
+  __device__ __forceinline__ T operator()(const T* args) const {
+    return std::abs(args[0]);
+  }
+};
+
+template <typename T>
+class AbsKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* x = context.Input<Tensor>("X");
+    Tensor* out = context.Output<Tensor>("Out");
+    out->mutable_data<math::Real<T>>(context.GetPlace());
+
+    auto& dev_ctx =
+        context.template device_context<platform::CUDADeviceContext>();
+    std::vector<const framework::Tensor*> ins = {x};
+    std::vector<framework::Tensor*> outs = {out};
+    auto functor = CudaAbsFunctor<T>();
+    LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, math::Real<T>>(
+        dev_ctx, ins, &outs, functor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
 REGISTER_OP_CUDA_KERNEL(
-    abs, ops::AbsKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext,
-                   paddle::platform::float16>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext,
-                   paddle::platform::complex64>,
-    ops::AbsKernel<paddle::platform::CUDADeviceContext,
-                   paddle::platform::complex128>);
+    abs, ops::AbsKernel<plat::CUDADeviceContext, float>,
+    ops::AbsKernel<plat::CUDADeviceContext, double>,
+    ops::AbsKernel<plat::CUDADeviceContext, int>,
+    ops::AbsKernel<plat::CUDADeviceContext, int64_t>,
+    ops::AbsKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::AbsKernel<plat::CUDADeviceContext, plat::complex64>,
+    ops::AbsKernel<plat::CUDADeviceContext, plat::complex128>);
 
 REGISTER_OP_CUDA_KERNEL(
-    abs_grad, ops::AbsGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext,
-                       paddle::platform::float16>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext,
-                       paddle::platform::complex64>,
-    ops::AbsGradKernel<paddle::platform::CUDADeviceContext,
-                       paddle::platform::complex128>);
+    abs_grad, ops::AbsGradKernel<plat::CUDADeviceContext, float>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, double>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, int>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, int64_t>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex64>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex128>);
 
 REGISTER_OP_CUDA_KERNEL(
-    abs_grad_grad,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext,
-                             paddle::platform::float16>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext,
-                             paddle::platform::complex64>,
-    ops::AbsDoubleGradKernel<paddle::platform::CUDADeviceContext,
-                             paddle::platform::complex128>);
+    abs_grad_grad, ops::AbsDoubleGradKernel<plat::CUDADeviceContext, float>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, double>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex64>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex128>);
diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu
index 22f8147111ffa5be91813738ff147a19b9ef22bc..618f17031b1ef3b4b96ea72b05f9f63edd01c794 100644
--- a/paddle/fluid/operators/activation_op.cu
+++ b/paddle/fluid/operators/activation_op.cu
@@ -1315,8 +1315,8 @@ class ActivationCudaKernel
     for (auto& attr : attrs) {
       *attr.second = ctx.Attr<float>(attr.first);
     }
-    LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T>(dev_ctx, ins, &outs,
-                                                            functor);
+    LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(dev_ctx, ins,
+                                                               &outs, functor);
   }
 };
 
@@ -1345,17 +1345,17 @@ class ActivationGradCudaKernel
     if (static_cast<int>(Functor::FwdDeps()) == static_cast<int>(kDepOut)) {
       // Only need forward output Out
       ins.push_back(out);
-      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T>(dev_ctx, ins,
-                                                               &outs, functor);
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx, ins, &outs, functor);
     } else if (static_cast<int>(Functor::FwdDeps()) ==
                static_cast<int>(kDepX)) {
       // Only need forward input X
       ins.push_back(x);
-      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T>(dev_ctx, ins,
-                                                               &outs, functor);
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx, ins, &outs, functor);
     } else {
-      LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T>(dev_ctx, ins,
-                                                              &outs, functor);
+      LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+          dev_ctx, ins, &outs, functor);
     }
   }
 };
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
index 5c444e752e797571e525f9f4b0319146988c7683..dc9c18ba038861b763cb52863ddae8ac69db5022 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
@@ -45,7 +45,7 @@ struct SameDimsElemwiseAdd {
                   framework::Tensor* z) {
     std::vector<const framework::Tensor*> ins = {x, y};
     std::vector<framework::Tensor*> outs = {z};
-    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T>(
+    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
         ctx.template device_context<platform::CUDADeviceContext>(), ins, &outs,
         CudaAddFunctor<T>());
   }
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
index 321826ec647c99345ac0769c88ac4ffa2be5b0db..38b1afbdc3342e8bc4d9901b64bae808fd9d3915 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -49,69 +49,73 @@ int GetVectorizedSizeImpl(const T *pointer) {
   return 1;
 }
 
-template <typename T>
+template <typename InT, typename OutT>
 int GetVectorizedSize(const std::vector<const framework::Tensor *> &ins,
                       const std::vector<framework::Tensor *> &outs) {
   int vec_size = 4;
   for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
     vec_size =
-        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<T>()));
+        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<InT>()));
   }
   for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
     vec_size =
-        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<T>()));
+        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<OutT>()));
   }
   return vec_size;
 }
 
-template <ElementwiseType ET, int VecSize, typename T>
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT>
 struct ElementwiseDataWrapper {
-  T *out;
-  const T *in0;
-  const T *in1;
-  __device__ ElementwiseDataWrapper(T *out, const T *in0,
-                                    const T *in1 = nullptr)
+  OutT *out;
+  const InT *in0;
+  const InT *in1;
+  __device__ ElementwiseDataWrapper(OutT *out, const InT *in0,
+                                    const InT *in1 = nullptr)
       : out(out), in0(in0), in1(in1) {}
 
-  using VecType = CudaAlignedVector<T, VecSize>;
+  using InVecType = CudaAlignedVector<InT, VecSize>;
+  using OutVecType = CudaAlignedVector<OutT, VecSize>;
 
-  inline __device__ void load_vector(VecType args[], int idx) {
-    const VecType *x_vec = reinterpret_cast<const VecType *>(in0);
+  inline __device__ void load_vector(InVecType args[], int idx) {
+    const InVecType *x_vec = reinterpret_cast<const InVecType *>(in0);
     args[0] = x_vec[idx];
     if (ET == ElementwiseType::kBinary) {
-      const VecType *y_vec = reinterpret_cast<const VecType *>(in1);
+      const InVecType *y_vec = reinterpret_cast<const InVecType *>(in1);
       args[1] = y_vec[idx];
     }
   }
 
-  inline __device__ void load_scalar(T args[], int idx) {
+  inline __device__ void load_scalar(InT args[], int idx) {
     args[0] = in0[idx];
     if (ET == ElementwiseType::kBinary) {
      args[1] = in1[idx];
    }
  }
 
-  inline __device__ void store_vector(VecType res, int idx) {
-    VecType *out_vec = reinterpret_cast<VecType *>(out);
+  inline __device__ void store_vector(OutVecType res, int idx) {
+    OutVecType *out_vec = reinterpret_cast<OutVecType *>(out);
     out_vec[idx] = res;
   }
 
-  inline __device__ void store_scalar(T res, int idx) { out[idx] = res; }
+  inline __device__ void store_scalar(OutT res, int idx) { out[idx] = res; }
 };
 
-template <ElementwiseType ET, int VecSize, typename T, typename Functor>
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
+          typename Functor>
 __device__ void VectorizedKernelImpl(
-    ElementwiseDataWrapper<ET, VecSize, T> data, Functor func, int tid) {
-  using VecType = CudaAlignedVector<T, VecSize>;
-  VecType ins_vec[ET];
-  VecType out_vec;
-  T *ins_ptr[ET];
-  T *out_ptr;
+    ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
+    int tid) {
+  using InVecType = CudaAlignedVector<InT, VecSize>;
+  using OutVecType = CudaAlignedVector<OutT, VecSize>;
+  InVecType ins_vec[ET];
+  OutVecType out_vec;
+  InT *ins_ptr[ET];
+  OutT *out_ptr;
 #pragma unroll
   for (int i = 0; i < ET; ++i) {
-    ins_ptr[i] = reinterpret_cast<T *>(&(ins_vec[i]));
+    ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
   }
-  out_ptr = reinterpret_cast<T *>(&out_vec);
+  out_ptr = reinterpret_cast<OutT *>(&out_vec);
 
   // load
   data.load_vector(ins_vec, tid);
@@ -119,7 +123,7 @@ __device__ void VectorizedKernelImpl(
   // compute
 #pragma unroll
   for (int i = 0; i < VecSize; ++i) {
-    T ins[ET];
+    InT ins[ET];
 #pragma unroll
     for (int j = 0; j < ET; ++j) {
       ins[j] = ins_ptr[j][i];
@@ -131,11 +135,13 @@ __device__ void VectorizedKernelImpl(
   data.store_vector(out_vec, tid);
 }
 
-template <ElementwiseType ET, int VecSize, typename T, typename Functor>
-__device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, VecSize, T> data,
-                                 Functor func, int start, int remain) {
-  T ins[ET];
-  T out;
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
+          typename Functor>
+__device__ void ScalarKernelImpl(
+    ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
+    int start, int remain) {
+  InT ins[ET];
+  OutT out;
 
   for (int i = 0; i < remain; ++i) {
     int idx = start + i;
@@ -148,14 +154,15 @@ __device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, VecSize, T> data,
   }
 }
 
-template <ElementwiseType ET, int VecSize, typename T, typename Functor>
-__global__ void VectorizedKernel(const T *__restrict__ in0,
-                                 const T *__restrict__ in1, T *out, int size,
-                                 Functor func) {
+template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
+          typename Functor>
+__global__ void VectorizedKernel(const InT *__restrict__ in0,
+                                 const InT *__restrict__ in1, OutT *out,
+                                 int size, Functor func) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int remain = size - VecSize * tid;
   remain = remain > 0 ? remain : 0;
-  auto data = ElementwiseDataWrapper<ET, VecSize, T>(out, in0, in1);
+  auto data = ElementwiseDataWrapper<ET, VecSize, InT, OutT>(out, in0, in1);
   if (remain >= VecSize) {
     VectorizedKernelImpl(data, func, tid);
   } else {
@@ -163,30 +170,31 @@ __global__ void VectorizedKernel(const T *__restrict__ in0,
   }
 }
 
-template <ElementwiseType ET, typename T, typename Functor>
-__global__ void ScalarKernel(const T *__restrict__ in0,
-                             const T *__restrict__ in1, T *out, int size,
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
+__global__ void ScalarKernel(const InT *__restrict__ in0,
+                             const InT *__restrict__ in1, OutT *out, int size,
                              Functor func) {
-  auto data = ElementwiseDataWrapper<ET, 1, T>(out, in0, in1);
+  auto data = ElementwiseDataWrapper<ET, 1, InT, OutT>(out, in0, in1);
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int remain = tid < size ? 1 : 0;
   ScalarKernelImpl(data, func, tid, remain);
 }
 
-template <ElementwiseType ET, typename T, typename Functor>
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
 void LaunchElementwiseCudaKernel(
     const platform::CUDADeviceContext &ctx,
     const std::vector<const framework::Tensor *> &ins,
     std::vector<framework::Tensor *> *outs, Functor func) {
   // calculate the max vec_size for all ins and outs
   auto size = ins[0]->numel();
-  int vec_size = GetVectorizedSize<T>(ins, *outs);
+  int vec_size = GetVectorizedSize<InT, OutT>(ins, *outs);
   int block_size = ELEMENTWISE_BLOCK_SIZE;
   int grid_size =
       ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
-  const T *in0 = ins[0]->data<T>();
-  const T *in1 = (ET == ElementwiseType::kBinary) ? ins[1]->data<T>() : nullptr;
-  T *out = (*outs)[0]->data<T>();
+  const InT *in0 = ins[0]->data<InT>();
+  const InT *in1 =
+      (ET == ElementwiseType::kBinary) ? ins[1]->data<InT>() : nullptr;
+  OutT *out = (*outs)[0]->data<OutT>();
   // cuda kernel
   auto stream = ctx.stream();
   switch (vec_size) {