Unverified commit aaa14780, authored by Huang Jiyi, committed by GitHub

register fluid activation kernel to phi (#51927)

* update

* update

* update

* update

* update

* fix test
Parent: 2add31f4
@@ -518,6 +518,8 @@ function(op_library TARGET)
   foreach(xpu_kp_src ${xpu_kp_cc_srcs})
     set(op_name "")
     find_register(${xpu_kp_src} "REGISTER_OP_KERNEL" op_name)
+    find_phi_register(${xpu_kp_src} ${pybind_file}
+                      "PD_REGISTER_STRUCT_KERNEL")
     if(NOT ${op_name} EQUAL "")
       file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, KP);\n")
       message(STATUS "Building KP Target: ${op_name}")
......
@@ -166,7 +166,7 @@ TEST(DisMultiTrainerTest, test3) {
   tmp1->SetDebug(true);
   ProgramDesc p;
   tmp1->InitOtherEnv(p);
-  tmp1->Run();
+  // tmp1->Run();
   tmp1->Finalize();
 #endif
 }
......
@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/common_infer_shape_functions.h"
 #include "paddle/phi/backends/dynload/port.h"
+#include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/infermeta/backward.h"

 DECLARE_bool(use_mkldnn);
@@ -384,6 +385,18 @@ DECLARE_INPLACE_OP_INFERER(ActivationTripleGradOpInplaceInferer,
                            {"DDX", "D_DOut"});
 DECLARE_INPLACE_OP_INFERER(ActFwdInplaceInferer, {"X", "Out"});

+#define DEFINE_ACTIVATION_CPU_KERNEL(op_name, functor, grad_functor)           \
+  template <typename T, typename DeviceContext>                                \
+  class op_name##Kernel : public ActivationKernel<DeviceContext, functor<T>> { \
+  };                                                                            \
+                                                                                \
+  template <typename T, typename DeviceContext>                                \
+  class op_name##GradKernel                                                    \
+      : public ActivationGradKernel<DeviceContext, grad_functor<T>> {};
+
+DEFINE_ACTIVATION_CPU_KERNEL(SoftRelu, SoftReluFunctor, SoftReluGradFunctor)
+
 }  // namespace operators
 }  // namespace paddle
@@ -407,19 +420,19 @@ namespace plat = paddle::platform;
                     ops::ActivationOpGrad,                  \
                     ops::ActivationGradOpInplaceInferer);

-#define REGISTER_ACTIVATION_CPU_KERNEL(                                      \
-    act_type, op_name, functor, grad_functor)                                \
-  REGISTER_OP_CPU_KERNEL(                                                    \
-      act_type,                                                              \
-      ops::ActivationKernel<phi::CPUContext, ops::functor<float>>,           \
-      ops::ActivationKernel<phi::CPUContext, ops::functor<double>>);         \
-  REGISTER_OP_CPU_KERNEL(                                                    \
-      act_type##_grad,                                                       \
-      ops::ActivationGradKernel<phi::CPUContext, ops::grad_functor<float>>,  \
-      ops::ActivationGradKernel<phi::CPUContext, ops::grad_functor<double>>);
-
 FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
-FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
+
+#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name)                \
+  PD_REGISTER_STRUCT_KERNEL(                                             \
+      act_type, CPU, ALL_LAYOUT, ops::op_name##Kernel, float, double) {} \
+  PD_REGISTER_STRUCT_KERNEL(act_type##_grad,                             \
+                            CPU,                                         \
+                            ALL_LAYOUT,                                  \
+                            ops::op_name##GradKernel,                    \
+                            float,                                       \
+                            double) {}
+
+REGISTER_ACTIVATION_CPU_KERNEL(soft_relu, SoftRelu)

 REGISTER_ACTIVATION_OP(relu6, Relu6, Relu6Functor, Relu6GradFunctor);
 REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
......
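Reviewer note (not part of the patch): the CPU hunks above replace direct REGISTER_OP_CPU_KERNEL calls with per-op kernel classes plus phi struct-kernel registration. As a rough illustration, a manual expansion of the two new macros for soft_relu, assuming the surrounding paddle::operators namespace of this file, would look like the sketch below.

// Illustrative expansion of
// DEFINE_ACTIVATION_CPU_KERNEL(SoftRelu, SoftReluFunctor, SoftReluGradFunctor):
// the per-op classes are thin wrappers over the existing functor-based
// ActivationKernel / ActivationGradKernel templates.
template <typename T, typename DeviceContext>
class SoftReluKernel
    : public ActivationKernel<DeviceContext, SoftReluFunctor<T>> {};

template <typename T, typename DeviceContext>
class SoftReluGradKernel
    : public ActivationGradKernel<DeviceContext, SoftReluGradFunctor<T>> {};

// Illustrative expansion of REGISTER_ACTIVATION_CPU_KERNEL(soft_relu, SoftRelu):
// the wrapper classes are registered through phi's struct-kernel registry
// instead of the old fluid registration macros.
PD_REGISTER_STRUCT_KERNEL(
    soft_relu, CPU, ALL_LAYOUT, ops::SoftReluKernel, float, double) {}
PD_REGISTER_STRUCT_KERNEL(
    soft_relu_grad, CPU, ALL_LAYOUT, ops::SoftReluGradKernel, float, double) {}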
@@ -192,87 +192,41 @@ template <typename T>
 using CudaELUGradNegativeAlphaFunctor =
     phi::funcs::CudaELUGradNegativeAlphaFunctor<T>;

+#define DEFINE_ACTIVATION_CUDA_KERNEL(op_name, functor, grad_functor)       \
+  template <typename T, typename DeviceContext>                             \
+  class op_name##CudaKernel                                                 \
+      : public ActivationCudaKernel<DeviceContext, functor<T>> {};          \
+                                                                            \
+  template <typename T, typename DeviceContext>                             \
+  class op_name##GradCudaKernel                                             \
+      : public ActivationGradCudaKernel<DeviceContext, grad_functor<T>> {};
+
+DEFINE_ACTIVATION_CUDA_KERNEL(SoftRelu,
+                              CudaSoftReluFunctor,
+                              CudaSoftReluGradFunctor)
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
 namespace plat = paddle::platform;

-#define REGISTER_ACTIVATION_CUDA_KERNEL(                                       \
-    act_type, op_name, functor, grad_functor)                                  \
-  REGISTER_OP_CUDA_KERNEL(                                                     \
-      act_type,                                                                \
-      ops::ActivationCudaKernel<phi::GPUContext, ops::functor<float>>,         \
-      ops::ActivationCudaKernel<phi::GPUContext, ops::functor<double>>,        \
-      ops::ActivationCudaKernel<phi::GPUContext, ops::functor<plat::float16>>, \
-      ops::ActivationCudaKernel<phi::GPUContext,                               \
-                                ops::functor<plat::bfloat16>>);                \
-  REGISTER_OP_CUDA_KERNEL(                                                     \
-      act_type##_grad,                                                         \
-      ops::ActivationGradCudaKernel<phi::GPUContext,                           \
-                                    ops::grad_functor<float>>,                 \
-      ops::ActivationGradCudaKernel<phi::GPUContext,                           \
-                                    ops::grad_functor<double>>,                \
-      ops::ActivationGradCudaKernel<phi::GPUContext,                           \
-                                    ops::grad_functor<plat::float16>>,         \
-      ops::ActivationGradCudaKernel<phi::GPUContext,                           \
-                                    ops::grad_functor<plat::bfloat16>>);
-
-#define REGISTER_ACTIVATION_CUDA_KERNEL_INT(                                   \
-    act_type, op_name, functor, grad_functor)                                  \
-  REGISTER_OP_CUDA_KERNEL(                                                     \
-      act_type,                                                                \
-      ops::ActivationCudaKernel<phi::GPUContext, ops::functor<float>>,         \
-      ops::ActivationCudaKernel<phi::GPUContext, ops::functor<double>>,        \
-      ops::ActivationCudaKernel<phi::GPUContext, ops::functor<int>>,           \
-      ops::ActivationCudaKernel<phi::GPUContext, ops::functor<int64_t>>,       \
-      ops::ActivationCudaKernel<phi::GPUContext, ops::functor<plat::float16>>, \
-      ops::ActivationCudaKernel<phi::GPUContext,                               \
-                                ops::functor<plat::bfloat16>>);                \
-  REGISTER_OP_CUDA_KERNEL(                                                     \
-      act_type##_grad,                                                         \
-      ops::ActivationGradCudaKernel<phi::GPUContext,                           \
-                                    ops::grad_functor<float>>,                 \
-      ops::ActivationGradCudaKernel<phi::GPUContext,                           \
-                                    ops::grad_functor<double>>,                \
-      ops::ActivationGradCudaKernel<phi::GPUContext, ops::grad_functor<int>>,  \
-      ops::ActivationGradCudaKernel<phi::GPUContext,                           \
-                                    ops::grad_functor<int64_t>>,               \
-      ops::ActivationGradCudaKernel<phi::GPUContext,                           \
-                                    ops::grad_functor<plat::float16>>,         \
-      ops::ActivationGradCudaKernel<phi::GPUContext,                           \
-                                    ops::grad_functor<plat::bfloat16>>);
-
-REGISTER_OP_CUDA_KERNEL(
-    relu6,
-    ops::ActivationCudaKernel<phi::GPUContext, ops::CudaRelu6Functor<float>>,
-    ops::ActivationCudaKernel<phi::GPUContext, ops::CudaRelu6Functor<double>>,
-    ops::ActivationCudaKernel<phi::GPUContext, ops::CudaRelu6Functor<int>>,
-    ops::ActivationCudaKernel<phi::GPUContext, ops::CudaRelu6Functor<int64_t>>,
-    ops::ActivationCudaKernel<phi::GPUContext,
-                              ops::CudaRelu6Functor<plat::float16>>,
-    ops::ActivationCudaKernel<phi::GPUContext,
-                              ops::CudaRelu6Functor<plat::bfloat16>>);
-REGISTER_OP_CUDA_KERNEL(
-    relu6_grad,
-    ops::ActivationGradCudaKernel<phi::GPUContext,
-                                  ops::CudaRelu6GradFunctor<float>>,
-    ops::ActivationGradCudaKernel<phi::GPUContext,
-                                  ops::CudaRelu6GradFunctor<double>>,
-    ops::ActivationGradCudaKernel<phi::GPUContext,
-                                  ops::CudaRelu6GradFunctor<int>>,
-    ops::ActivationGradCudaKernel<phi::GPUContext,
-                                  ops::CudaRelu6GradFunctor<int64_t>>,
-    ops::ActivationGradCudaKernel<phi::GPUContext,
-                                  ops::CudaRelu6GradFunctor<plat::float16>>,
-    ops::ActivationGradCudaKernel<phi::GPUContext,
-                                  ops::CudaRelu6GradFunctor<plat::bfloat16>>);
-
-#define FOR_EACH_ACTIVATION_CUDA_OP(__macro)                                   \
-  __macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor);  \
-  __macro(softsign, Softsign, CudaSoftsignFunctor, CudaSoftsignGradFunctor);
-
-FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)
+PD_REGISTER_STRUCT_KERNEL(soft_relu,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::SoftReluCudaKernel,
+                          float,
+                          double,
+                          plat::float16,
+                          plat::bfloat16) {}
+PD_REGISTER_STRUCT_KERNEL(soft_relu_grad,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::SoftReluGradCudaKernel,
+                          float,
+                          double,
+                          plat::float16,
+                          plat::bfloat16) {}

 #ifdef PADDLE_WITH_XPU_KP
 REGISTER_OP_KERNEL(
......
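Reviewer note (not part of the patch): the GPU hunk follows the same pattern as the CPU one. A rough manual expansion of DEFINE_ACTIVATION_CUDA_KERNEL(SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor) from the hunk above, assuming the surrounding paddle::operators namespace, shows where the ops::SoftReluCudaKernel and ops::SoftReluGradCudaKernel names referenced by the PD_REGISTER_STRUCT_KERNEL calls come from.

// Illustrative expansion only; the classes simply wrap the existing
// CUDA activation functors in the templated kernel base classes.
template <typename T, typename DeviceContext>
class SoftReluCudaKernel
    : public ActivationCudaKernel<DeviceContext, CudaSoftReluFunctor<T>> {};

template <typename T, typename DeviceContext>
class SoftReluGradCudaKernel
    : public ActivationGradCudaKernel<DeviceContext,
                                      CudaSoftReluGradFunctor<T>> {};

With the classes declared this way, the PD_REGISTER_STRUCT_KERNEL calls in the hunk register one instantiation per listed dtype (float, double, float16, bfloat16) for the GPU backend, taking over from the removed REGISTER_OP_CUDA_KERNEL lists.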
@@ -2443,6 +2443,9 @@ class TestSoftRelu(TestActivation):
         self.attrs = {'threshold': threshold}
         self.outputs = {'Out': out}

+    def test_check_output(self):
+        self.check_output(check_dygraph=False)
+
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
@@ -3856,6 +3859,7 @@ def create_test_act_fp16_class(
     parent,
     atol=1e-3,
     grad_check=True,
+    check_dygraph=True,
     check_prim=False,
     enable_cinn=True,
     grad_atol=0.80,
@@ -3875,7 +3879,10 @@ def create_test_act_fp16_class(
             support_fp16 = core.is_float16_supported(place)
             if support_fp16:
                 self.check_output_with_place(
-                    place, atol=atol, check_prim=check_prim
+                    place,
+                    atol=atol,
+                    check_dygraph=check_dygraph,
+                    check_prim=check_prim,
                 )

         def test_check_grad(self):
@@ -3886,6 +3893,7 @@ def create_test_act_fp16_class(
                     place,
                     ['X'],
                     'Out',
+                    check_dygraph=check_dygraph,
                     check_prim=check_prim,
                     max_relative_error=grad_atol,
                 )
@@ -3925,7 +3933,7 @@ create_test_act_fp16_class(TestRelu, check_prim=True)
 create_test_act_fp16_class(TestGelu, check_prim=True, enable_cinn=False)
 create_test_act_fp16_class(TestBRelu)
 create_test_act_fp16_class(TestRelu6)
-create_test_act_fp16_class(TestSoftRelu, grad_atol=0.85)
+create_test_act_fp16_class(TestSoftRelu, check_dygraph=False, grad_atol=0.85)
 create_test_act_fp16_class(TestELU)
 create_test_act_fp16_class(TestCELU)
 create_test_act_fp16_class(TestReciprocal)
......