diff --git a/paddle/fluid/operators/where_op.cu b/paddle/fluid/operators/where_op.cu index 54b0d5b69086cda3ebdefa76636aff734d1a150c..61a1691e4fe265035917ed2407d5e3e24aa6bd88 100644 --- a/paddle/fluid/operators/where_op.cu +++ b/paddle/fluid/operators/where_op.cu @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" #include "paddle/fluid/operators/where_op.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" @@ -20,6 +21,15 @@ namespace platform = paddle::platform; namespace paddle { namespace operators { +template +struct CondFunctor { + HOSTDEVICE inline CondFunctor() {} + + HOSTDEVICE inline T operator()(const bool cond, const T x, const T y) const { + return cond ? x : y; + } +}; + template __global__ void WhereCUDAKernel(const int N, const bool* cond, const T* x, const T* y, T* out) { @@ -63,10 +73,11 @@ class WhereKernel auto stream = context.cuda_device_context().stream(); auto& dev_ctx = context.template device_context(); - auto config = GetGpuLaunchConfig1D(dev_ctx, numel); - WhereCUDAKernel< - T><<>>( - numel, cond_data, x_data, y_data, out_data); + auto functor = CondFunctor(); + std::vector ins = {condition, X, Y}; + std::vector outs = {out}; + paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, + &outs, functor); } }; diff --git a/paddle/phi/kernels/funcs/complex_functors.h b/paddle/phi/kernels/funcs/complex_functors.h index 450adfcc68b7e84e27a2f6bf2c6c22551bab8892..86dbdd099ecde72e932cc6cfa492486b65c7ebc2 100644 --- a/paddle/phi/kernels/funcs/complex_functors.h +++ b/paddle/phi/kernels/funcs/complex_functors.h @@ -154,6 +154,53 @@ struct AbsFunctor>> { int64_t numel_; }; +template +struct AbsGradCUDAFunctor { + HOSTDEVICE inline AbsGradCUDAFunctor() {} + + HOSTDEVICE inline T operator()(const T x, const T dout) const { + T output; + if (x == T(0)) { + output = T(0); + } else { + output = T(dout) * (x / T(std::abs(x))); + } + return output; + } +}; + +template <> +struct AbsGradCUDAFunctor> { + HOSTDEVICE inline AbsGradCUDAFunctor() {} + HOSTDEVICE inline phi::dtype::complex operator()( + const phi::dtype::complex x, const float dout) const { + phi::dtype::complex output; + if (x == phi::dtype::complex(0)) { + output = phi::dtype::complex(0); + } else { + output = phi::dtype::complex(dout) * + (x / phi::dtype::complex(abs(x))); + } + return output; + } +}; + +template <> +struct AbsGradCUDAFunctor> { + HOSTDEVICE inline AbsGradCUDAFunctor() {} + HOSTDEVICE inline phi::dtype::complex operator()( + const phi::dtype::complex x, const double dout) const { + phi::dtype::complex output; + if (x == phi::dtype::complex(0)) { + output = phi::dtype::complex(0); + } else { + output = phi::dtype::complex(dout) * + (x / phi::dtype::complex(abs(x))); + } + return output; + } +}; + template struct AbsGradFunctor { AbsGradFunctor(const Real* dout, const T* x, T* output, int64_t numel) diff --git a/paddle/phi/kernels/impl/abs_grad_kernel_impl.h b/paddle/phi/kernels/impl/abs_grad_kernel_impl.h index 939bc49c9fc671ac148688ca6556e982d8ee5523..4b31393a71f3623bff168dfc17612ceda250c506 100644 --- a/paddle/phi/kernels/impl/abs_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/abs_grad_kernel_impl.h @@ -17,9 +17,30 @@ #include "paddle/fluid/platform/for_range.h" #include "paddle/phi/kernels/abs_grad_kernel.h" #include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" namespace phi { +#if defined(__NVCC__) +template +void AbsGradKernelImpl(const GPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + DenseTensor* dx) { + std::vector ins = {&x, &dout}; + std::vector outs = {dx}; + dev_ctx.Alloc(dx); + phi::funcs::AbsGradCUDAFunctor abs_grad_cuda_functor; + phi::funcs::ElementwiseKernel(dev_ctx, ins, &outs, abs_grad_cuda_functor); +} +template +void AbsGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + DenseTensor* dx) { + AbsGradKernelImpl(dev_ctx, x, dout, dx); +} +#else template void AbsGradKernel(const Context& ctx, const DenseTensor& x, @@ -37,6 +58,7 @@ void AbsGradKernel(const Context& ctx, for_range(functor); } +#endif template void AbsDoubleGradKernel(const Context& ctx, const DenseTensor& x,