diff --git a/cmake/operators.cmake b/cmake/operators.cmake index bf40948cf7c791ebb8c93fd8b8a21263e8cc559c..1897ded550d00118517e504c5490d0da932e54fa 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -219,6 +219,16 @@ function(op_library TARGET) list(APPEND mlu_cc_srcs ${src}) elseif(${src} MATCHES ".*\\.cc$") list(APPEND cc_srcs ${src}) + elseif((WITH_ROCM OR WITH_GPU) AND ${src} MATCHES ".*\\.kps$") + string(REPLACE ".kps" ".cu" src_cu ${src}) + file(COPY ${src} DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) + file(RENAME ${CMAKE_CURRENT_BINARY_DIR}/${src} + ${CMAKE_CURRENT_BINARY_DIR}/${src_cu}) + if(WITH_ROCM) + list(APPEND hip_srcs ${CMAKE_CURRENT_BINARY_DIR}/${src_cu}) + else() + list(APPEND cu_srcs ${CMAKE_CURRENT_BINARY_DIR}/${src_cu}) + endif() else() message( FATAL_ERROR diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index c56bb972b24fa8f5350b4cde9f184d3c6ad0f0b4..d39aeedd45908bb26193a85d9c762741d07b34c8 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -97,7 +97,7 @@ endif() set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_utils backward_infermeta sparse_backward_infermeta static_prim_api get_expected_kernel_func) register_operators(EXCLUDES py_func_op warpctc_op dgc_op generated_op1 generated_op2 generated_op3 generated_op4 load_combine_op lstm_op run_program_op eye_op quantize_linear_op - recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) + recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op activation_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) op_library(generated_op UNITY SRCS generated_op1.cc generated_op2.cc generated_op3.cc generated_op4.cc DEPS ${OP_HEADER_DEPS}) op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc run_program_op_npu.cc DEPS executor_cache ${OP_HEADER_DEPS}) @@ -106,6 +106,14 @@ op_library(quantize_linear_op DEPS phi) op_library(save_combine_op DEPS string_array phi) op_library(load_combine_op DEPS string_array) +if (WITH_GPU OR WITH_ROCM) + op_library(activation_op SRCS activation_op.cc activation_op.kps soft_relu_op.cu DEPS ${OP_HEADER_DEPS}) +elseif (WITH_XPU_KP) + op_library(activation_op SRCS activation_op.cc activation_op.kps DEPS ${OP_HEADER_DEPS}) +else() + op_library(activation_op SRCS activation_op.cc DEPS ${OP_HEADER_DEPS}) +endif() + if (WITH_GPU OR WITH_ROCM) if(WITH_ROCM) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc) diff --git a/paddle/fluid/operators/activation_op.cu.h b/paddle/fluid/operators/activation_op.cu.h new file mode 100644 index 0000000000000000000000000000000000000000..08a8d9a08960936bcd2d660cc368d4fda0a61184 --- /dev/null +++ b/paddle/fluid/operators/activation_op.cu.h @@ -0,0 +1,136 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" +#include "paddle/fluid/platform/bfloat16.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" +#include "paddle/phi/kernels/funcs/activation_functor.h" + +namespace paddle { +namespace operators { + +template +struct CudaSoftReluFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + // soft_relu(x) = log(1 + exp(max(min(x, threshold), -threshold))) + // threshold should not be negative + __device__ __forceinline__ T operator()(const T arg_x) const { + MPType x = static_cast(arg_x); + MPType t = static_cast(threshold); + MPType temp_min = x < t ? x : t; + MPType temp_max = temp_min > -t ? temp_min : -t; + return static_cast(log(one + exp(temp_max))); + } +}; + +template +struct CudaSoftReluGradFunctor : public BaseActivationFunctor { + using MPType = typename details::MPTypeTrait::Type; + MPType one = static_cast(1.0f); + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + // dx = (out > -threshold && out < threshold) ? dout * (1 - exp(-out)) : 0 + // threshold should not be negative + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_out) const { + MPType dout = static_cast(arg_dout); + MPType out = static_cast(arg_out); + MPType t = static_cast(threshold); + return (out > -t && out < t) ? static_cast(dout * (one - exp(-out))) + : static_cast(0.0f); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + +template +class ActivationCudaKernel + : public framework::OpKernel { + public: + using T = typename Functor::ELEMENT_TYPE; + void Compute(const framework::ExecutionContext& ctx) const override { + const phi::DenseTensor* x = nullptr; + phi::DenseTensor* out = nullptr; + ExtractActivationTensor(ctx, &x, &out); + out->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + std::vector ins = {x}; + std::vector outs = {out}; + auto functor = Functor(); + auto attrs = functor.GetAttrs(); + for (auto& attr : attrs) { + *attr.second = ctx.Attr(attr.first); + } + paddle::operators::LaunchSameDimsElementwiseCudaKernel( + dev_ctx, ins, &outs, functor); + } +}; + +template +class ActivationGradCudaKernel + : public framework::OpKernel { + public: + using T = typename Functor::ELEMENT_TYPE; + void Compute(const framework::ExecutionContext& ctx) const override { + const phi::DenseTensor *x, *out, *d_out; + phi::DenseTensor* d_x = nullptr; + x = out = d_out = nullptr; + ExtractActivationGradTensor( + ctx, &x, &out, &d_out, &d_x); + d_x->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + auto functor = Functor(); + auto attrs = functor.GetAttrs(); + for (auto& attr : attrs) { + *attr.second = ctx.Attr(attr.first); + } + + std::vector ins = {d_out}; + std::vector outs = {d_x}; + + if (static_cast(Functor::FwdDeps()) == + static_cast(ActBwdOpFwdDeps::kDepOut)) { + // Only need forward output Out + ins.push_back(out); + paddle::operators::LaunchSameDimsElementwiseCudaKernel( + dev_ctx, ins, &outs, functor); + } else if (static_cast(Functor::FwdDeps()) == + static_cast(ActBwdOpFwdDeps::kDepX)) { + // Only need forward input X + ins.push_back(x); + paddle::operators::LaunchSameDimsElementwiseCudaKernel( + dev_ctx, ins, &outs, functor); + } else { + paddle::operators::LaunchSameDimsElementwiseCudaKernel( + dev_ctx, ins, &outs, functor); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/activation_op.kps b/paddle/fluid/operators/activation_op.kps index 41066cfa97925099dfba742cfc4aa9bbdcd9ecb6..e0113031a4b1d0882a4953adaa61cfbed98453c2 100644 --- a/paddle/fluid/operators/activation_op.kps +++ b/paddle/fluid/operators/activation_op.kps @@ -9,127 +9,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/operators/amp/fp16_type_traits.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" -#include "paddle/fluid/platform/bfloat16.h" -#include "paddle/phi/backends/gpu/gpu_device_function.h" -#include "paddle/phi/kernels/funcs/activation_functor.h" +#include "paddle/fluid/operators/activation_op.cu.h" namespace paddle { namespace operators { -template -struct CudaSoftReluFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - // soft_relu(x) = log(1 + exp(max(min(x, threshold), -threshold))) - // threshold should not be negative - __device__ __forceinline__ T operator()(const T arg_x) const { - MPType x = static_cast(arg_x); - MPType t = static_cast(threshold); - MPType temp_min = x < t ? x : t; - MPType temp_max = temp_min > -t ? temp_min : -t; - return static_cast(log(one + exp(temp_max))); - } -}; - -template -struct CudaSoftReluGradFunctor : public BaseActivationFunctor { - using MPType = typename details::MPTypeTrait::Type; - MPType one = static_cast(1.0f); - float threshold; - - typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; - } - - // dx = (out > -threshold && out < threshold) ? dout * (1 - exp(-out)) : 0 - // threshold should not be negative - __device__ __forceinline__ T operator()(const T arg_dout, - const T arg_out) const { - MPType dout = static_cast(arg_dout); - MPType out = static_cast(arg_out); - MPType t = static_cast(threshold); - return (out > -t && out < t) ? static_cast(dout * (one - exp(-out))) - : static_cast(0.0f); - } - - static constexpr ActBwdOpFwdDeps FwdDeps() { - return ActBwdOpFwdDeps::kDepOut; - } -}; - -template -class ActivationCudaKernel - : public framework::OpKernel { - public: - using T = typename Functor::ELEMENT_TYPE; - void Compute(const framework::ExecutionContext& ctx) const override { - const phi::DenseTensor* x = nullptr; - phi::DenseTensor* out = nullptr; - ExtractActivationTensor(ctx, &x, &out); - out->mutable_data(ctx.GetPlace()); - auto& dev_ctx = ctx.template device_context(); - std::vector ins = {x}; - std::vector outs = {out}; - auto functor = Functor(); - auto attrs = functor.GetAttrs(); - for (auto& attr : attrs) { - *attr.second = ctx.Attr(attr.first); - } - paddle::operators::LaunchSameDimsElementwiseCudaKernel( - dev_ctx, ins, &outs, functor); - } -}; - -template -class ActivationGradCudaKernel - : public framework::OpKernel { - public: - using T = typename Functor::ELEMENT_TYPE; - void Compute(const framework::ExecutionContext& ctx) const override { - const phi::DenseTensor *x, *out, *d_out; - phi::DenseTensor* d_x = nullptr; - x = out = d_out = nullptr; - ExtractActivationGradTensor( - ctx, &x, &out, &d_out, &d_x); - d_x->mutable_data(ctx.GetPlace()); - auto& dev_ctx = ctx.template device_context(); - auto functor = Functor(); - auto attrs = functor.GetAttrs(); - for (auto& attr : attrs) { - *attr.second = ctx.Attr(attr.first); - } - - std::vector ins = {d_out}; - std::vector outs = {d_x}; - - if (static_cast(Functor::FwdDeps()) == - static_cast(ActBwdOpFwdDeps::kDepOut)) { - // Only need forward output Out - ins.push_back(out); - paddle::operators::LaunchSameDimsElementwiseCudaKernel( - dev_ctx, ins, &outs, functor); - } else if (static_cast(Functor::FwdDeps()) == - static_cast(ActBwdOpFwdDeps::kDepX)) { - // Only need forward input X - ins.push_back(x); - paddle::operators::LaunchSameDimsElementwiseCudaKernel( - dev_ctx, ins, &outs, functor); - } else { - paddle::operators::LaunchSameDimsElementwiseCudaKernel( - dev_ctx, ins, &outs, functor); - } - } -}; - template using CudaBReluFunctor = phi::funcs::CudaHardTanhFunctor; template @@ -192,42 +76,12 @@ template using CudaELUGradNegativeAlphaFunctor = phi::funcs::CudaELUGradNegativeAlphaFunctor; -#define DEFINE_ACTIVATION_CUDA_KERNEL(op_name, functor, grad_functor) \ - template \ - class op_name##CudaKernel \ - : public ActivationCudaKernel> {}; \ - \ - template \ - class op_name##GradCudaKernel \ - : public ActivationGradCudaKernel> {}; - -DEFINE_ACTIVATION_CUDA_KERNEL(SoftRelu, - CudaSoftReluFunctor, - CudaSoftReluGradFunctor) - } // namespace operators } // namespace paddle namespace ops = paddle::operators; namespace plat = paddle::platform; -PD_REGISTER_STRUCT_KERNEL(soft_relu, - GPU, - ALL_LAYOUT, - ops::SoftReluCudaKernel, - float, - double, - plat::float16, - plat::bfloat16) {} -PD_REGISTER_STRUCT_KERNEL(soft_relu_grad, - GPU, - ALL_LAYOUT, - ops::SoftReluGradCudaKernel, - float, - double, - plat::float16, - plat::bfloat16) {} - #ifdef PADDLE_WITH_XPU_KP REGISTER_OP_KERNEL( brelu, diff --git a/paddle/fluid/operators/soft_relu_op.cu b/paddle/fluid/operators/soft_relu_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..3963b372c9c8eb543f5188a0efe5dbf03f5088b7 --- /dev/null +++ b/paddle/fluid/operators/soft_relu_op.cu @@ -0,0 +1,51 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/activation_op.cu.h" + +namespace paddle { +namespace operators { + +#define DEFINE_ACTIVATION_CUDA_KERNEL(op_name, functor, grad_functor) \ + template \ + class op_name##CudaKernel \ + : public ActivationCudaKernel> {}; \ + \ + template \ + class op_name##GradCudaKernel \ + : public ActivationGradCudaKernel> {}; + +DEFINE_ACTIVATION_CUDA_KERNEL(SoftRelu, + CudaSoftReluFunctor, + CudaSoftReluGradFunctor) + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +PD_REGISTER_STRUCT_KERNEL(soft_relu, + GPU, + ALL_LAYOUT, + ops::SoftReluCudaKernel, + float, + double, + plat::float16, + plat::bfloat16) {} +PD_REGISTER_STRUCT_KERNEL(soft_relu_grad, + GPU, + ALL_LAYOUT, + ops::SoftReluGradCudaKernel, + float, + double, + plat::float16, + plat::bfloat16) {}