From c4d5a77fec998ea21870d6479a0584daccf4aa0e Mon Sep 17 00:00:00 2001
From: zhangyikun02 <48021248+zhangyk0314@users.noreply.github.com>
Date: Wed, 13 Apr 2022 10:21:21 +0800
Subject: [PATCH] concat and relu support FP16 in XPU, test=kunlun (#41631)

---
 paddle/fluid/operators/activation_op_xpu.cc   |  8 ++++-
 paddle/fluid/operators/concat_op_xpu.cc       | 31 +++++++++++++------
 .../fluid/platform/device/xpu/xpu2_op_list.h  | 12 ++++---
 3 files changed, 36 insertions(+), 15 deletions(-)

diff --git a/paddle/fluid/operators/activation_op_xpu.cc b/paddle/fluid/operators/activation_op_xpu.cc
index 4c2d3fc162f..e950f952c24 100644
--- a/paddle/fluid/operators/activation_op_xpu.cc
+++ b/paddle/fluid/operators/activation_op_xpu.cc
@@ -490,7 +490,6 @@ REGISTER_ACTIVATION_XPU_KERNEL(leaky_relu, XPULeakyReluFunctor,
                                XPULeakyReluGradFunctor)
 REGISTER_ACTIVATION_XPU_KERNEL(reciprocal, XPUReciprocalFunctor,
                                XPUReciprocalGradFunctor)
-REGISTER_ACTIVATION_XPU_KERNEL(relu, XPUReluFunctor, XPUReluGradFunctor)
 REGISTER_ACTIVATION_XPU_KERNEL(sigmoid, XPUSigmoidFunctor,
                                XPUSigmoidGradFunctor)
 REGISTER_ACTIVATION_XPU_KERNEL(sqrt, XPUSqrtFunctor, XPUSqrtGradFunctor)
@@ -500,6 +499,13 @@ REGISTER_ACTIVATION_XPU_KERNEL(softplus, XPUSoftPlusFunctor,
 REGISTER_ACTIVATION_XPU_KERNEL(swish, XPUSwishFunctor, XPUSwishGradFunctor)
 REGISTER_ACTIVATION_XPU_KERNEL(pow, XPUPowFunctor, XPUPowGradFunctor)
 
+REGISTER_OP_XPU_KERNEL(
+    relu, ops::XPUActivationKernel<ops::XPUReluFunctor<float>>,
+    ops::XPUActivationKernel<ops::XPUReluFunctor<paddle::platform::float16>>);
+REGISTER_OP_XPU_KERNEL(
+    relu_grad, ops::XPUActivationGradKernel<ops::XPUReluGradFunctor<float>>,
+    ops::XPUActivationGradKernel<
+        ops::XPUReluGradFunctor<paddle::platform::float16>>);
 REGISTER_OP_XPU_KERNEL(
     tanh, ops::XPUActivationKernel<ops::XPUTanhFunctor<float>>,
     ops::XPUActivationKernel<ops::XPUTanhFunctor<paddle::platform::float16>>);
diff --git a/paddle/fluid/operators/concat_op_xpu.cc b/paddle/fluid/operators/concat_op_xpu.cc
index e4b0b0ee2e3..ba35098bbac 100644
--- a/paddle/fluid/operators/concat_op_xpu.cc
+++ b/paddle/fluid/operators/concat_op_xpu.cc
@@ -26,6 +26,8 @@ using Tensor = framework::Tensor;
 
 template <typename DeviceContext, typename T>
 class ConcatXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto ins = ctx.MultiInput<framework::LoDTensor>("X");
@@ -79,10 +81,10 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     out->mutable_data<T>(place);
     std::vector<std::vector<int>> xdims_list;
-    std::vector<const T*> ptrs;
+    std::vector<const XPUType*> ptrs;
     for (unsigned int i = 0; i < ins.size(); ++i) {
       if (ins[i] && ins[i]->numel() > 0) {
-        ptrs.push_back(ins[i]->data<T>());
+        ptrs.push_back(reinterpret_cast<const XPUType*>(ins[i]->data<T>()));
         int size = ins[i]->dims().size();
         std::vector<int> tmp_dims(size);
         for (int j = 0; j < size; ++j) {
@@ -96,8 +98,9 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
                           "No tensor need concat"));
 
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r = xpu::concat<T>(dev_ctx.x_context(), ptrs, out->data<T>(),
-                           xdims_list, axis);
+    int r = xpu::concat<XPUType>(dev_ctx.x_context(), ptrs,
+                                 reinterpret_cast<XPUType*>(out->data<T>()),
+                                 xdims_list, axis);
     PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                       platform::errors::External(
                           "XPU concat kernel return wrong value[%d %s]", r,
@@ -107,6 +110,8 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
 
 template <typename DeviceContext, typename T>
 class ConcatGradXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* out_grad =
@@ -134,12 +139,12 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
     axis = ComputeAxis(static_cast<int64_t>(axis),
                        static_cast<int64_t>(ins[0]->dims().size()));
     // get output tensor that the name is not kEmptyVarName
-    std::vector<T*> ptrs(outs.size());
+    std::vector<XPUType*> ptrs(outs.size());
     for (size_t j = 0; j < outs.size(); ++j) {
       if (out_var_names[j] != framework::kEmptyVarName &&
           outs[j]->numel() != 0UL) {
         outs[j]->mutable_data<T>(ctx.GetPlace());
-        ptrs[j] = outs[j]->data<T>();
+        ptrs[j] = reinterpret_cast<XPUType*>(outs[j]->data<T>());
       } else {
         ptrs[j] = nullptr;
       }
     }
@@ -173,8 +178,10 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
     xdims_list[axis] = total_length;
 
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r = xpu::split<T>(dev_ctx.x_context(), out_grad->data<T>(), ptrs,
-                          xdims_list, split_list, axis);
+    int r = xpu::split<XPUType>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(out_grad->data<T>()), ptrs, xdims_list,
+        split_list, axis);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External(
@@ -189,9 +196,13 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
 
 namespace ops = paddle::operators;
 REGISTER_OP_XPU_KERNEL(
-    concat, ops::ConcatXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    concat, ops::ConcatXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::ConcatXPUKernel<paddle::platform::XPUDeviceContext,
+                         paddle::platform::float16>);
 REGISTER_OP_XPU_KERNEL(
     concat_grad,
-    ops::ConcatGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::ConcatGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::ConcatGradXPUKernel<paddle::platform::XPUDeviceContext,
+                             paddle::platform::float16>);
 
 #endif
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 3a047b8fce7..9915b4d8d34 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -56,8 +56,10 @@ XPUOpMap& get_kl2_ops() {
                     pOpKernelType(vartype::INT64, XPUPlace()),
                     pOpKernelType(vartype::INT32, XPUPlace())})},
     {"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                                  pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                             pOpKernelType(vartype::FP16, XPUPlace())})},
     {"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                                   pOpKernelType(vartype::FP16, XPUPlace())})},
     {"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
@@ -288,8 +290,10 @@ XPUOpMap& get_kl2_ops() {
     {"reduce_sum_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                                pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                           pOpKernelType(vartype::FP16, XPUPlace())})},
    {"reshape2_grad", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
                                    pOpKernelType(vartype::INT64, XPUPlace()),
-- 
GitLab
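
Note on the pattern the patch relies on: the kernels stay templated on the framework type (float or paddle::platform::float16), and XPUTypeTrait<T>::Type maps that type to the device-side type expected by the xpu:: primitives; because the two FP16 representations are layout-compatible, a reinterpret_cast of the tensor pointers is enough and no conversion or copy is needed. The following is a minimal, self-contained C++ sketch of that trait-plus-cast bridging, not Paddle code: HostFP16, DeviceFP16, DeviceTypeTrait, device_relu and relu_kernel are hypothetical stand-ins for paddle::platform::float16, the XPU half type, XPUTypeTrait, and the xpu:: kernels.

// Sketch only: illustrates mapping a framework type to a device type via a
// trait and bridging the pointers with reinterpret_cast, as the patch does.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Host-side half-precision value, stored as raw 16 bits
// (stand-in for paddle::platform::float16).
struct HostFP16 { uint16_t bits; };

// Device-side half-precision value with the same layout
// (stand-in for the XPU runtime's float16 type).
struct DeviceFP16 { uint16_t bits; };

// Trait mapping a framework type to the type the device kernels expect
// (stand-in for XPUTypeTrait<T>).
template <typename T>
struct DeviceTypeTrait { using Type = T; };          // float passes through
template <>
struct DeviceTypeTrait<HostFP16> { using Type = DeviceFP16; };  // FP16 is cast

// Hypothetical device primitive that only understands device types.
// A real kernel would run on the device; this placeholder just copies
// the buffer so the sketch stays runnable.
template <typename DT>
void device_relu(const DT* x, DT* y, int n) {
  std::memcpy(y, x, sizeof(DT) * static_cast<size_t>(n));
}

// Framework-side kernel written against the framework type T. The trait plus
// reinterpret_cast bridges T and the device type without any data conversion.
template <typename T>
void relu_kernel(const std::vector<T>& x, std::vector<T>* y) {
  using DT = typename DeviceTypeTrait<T>::Type;
  static_assert(sizeof(T) == sizeof(DT), "framework/device layouts must match");
  y->resize(x.size());
  device_relu<DT>(reinterpret_cast<const DT*>(x.data()),
                  reinterpret_cast<DT*>(y->data()),
                  static_cast<int>(x.size()));
}

int main() {
  std::vector<HostFP16> x = {{0x3C00}, {0xBC00}};  // 1.0 and -1.0 in IEEE half
  std::vector<HostFP16> y;
  relu_kernel(x, &y);
  std::cout << "moved " << y.size() << " fp16 values through the device path\n";
  return 0;
}

The cast-only bridge works because the host and device FP16 types share size and bit layout; if they did not, an explicit conversion kernel would be required instead.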