From c4d5a77fec998ea21870d6479a0584daccf4aa0e Mon Sep 17 00:00:00 2001
From: zhangyikun02 <48021248+zhangyk0314@users.noreply.github.com>
Date: Wed, 13 Apr 2022 10:21:21 +0800
Subject: [PATCH] concat and relu support FP16 in XPU, test=kunlun (#41631)

---
 paddle/fluid/operators/activation_op_xpu.cc   |  8 ++++-
 paddle/fluid/operators/concat_op_xpu.cc       | 31 +++++++++++++------
 .../fluid/platform/device/xpu/xpu2_op_list.h  | 12 ++++---
 3 files changed, 36 insertions(+), 15 deletions(-)

diff --git a/paddle/fluid/operators/activation_op_xpu.cc b/paddle/fluid/operators/activation_op_xpu.cc
index 4c2d3fc162..e950f952c2 100644
--- a/paddle/fluid/operators/activation_op_xpu.cc
+++ b/paddle/fluid/operators/activation_op_xpu.cc
@@ -490,7 +490,6 @@ REGISTER_ACTIVATION_XPU_KERNEL(leaky_relu, XPULeakyReluFunctor,
                                XPULeakyReluGradFunctor)
 REGISTER_ACTIVATION_XPU_KERNEL(reciprocal, XPUReciprocalFunctor,
                                XPUReciprocalGradFunctor)
-REGISTER_ACTIVATION_XPU_KERNEL(relu, XPUReluFunctor, XPUReluGradFunctor)
 REGISTER_ACTIVATION_XPU_KERNEL(sigmoid, XPUSigmoidFunctor,
                                XPUSigmoidGradFunctor)
 REGISTER_ACTIVATION_XPU_KERNEL(sqrt, XPUSqrtFunctor, XPUSqrtGradFunctor)
@@ -500,6 +499,13 @@ REGISTER_ACTIVATION_XPU_KERNEL(softplus, XPUSoftPlusFunctor,
 REGISTER_ACTIVATION_XPU_KERNEL(swish, XPUSwishFunctor, XPUSwishGradFunctor)
 REGISTER_ACTIVATION_XPU_KERNEL(pow, XPUPowFunctor, XPUPowGradFunctor)
 
+REGISTER_OP_XPU_KERNEL(
+    relu, ops::XPUActivationKernel<ops::XPUReluFunctor<float>>,
+    ops::XPUActivationKernel<ops::XPUReluFunctor<paddle::platform::float16>>);
+REGISTER_OP_XPU_KERNEL(
+    relu_grad, ops::XPUActivationGradKernel<ops::XPUReluGradFunctor<float>>,
+    ops::XPUActivationGradKernel<
+        ops::XPUReluGradFunctor<paddle::platform::float16>>);
 REGISTER_OP_XPU_KERNEL(
     tanh, ops::XPUActivationKernel<ops::XPUTanhFunctor<float>>,
     ops::XPUActivationKernel<ops::XPUTanhFunctor<paddle::platform::float16>>);
diff --git a/paddle/fluid/operators/concat_op_xpu.cc b/paddle/fluid/operators/concat_op_xpu.cc
index e4b0b0ee2e..ba35098bba 100644
--- a/paddle/fluid/operators/concat_op_xpu.cc
+++ b/paddle/fluid/operators/concat_op_xpu.cc
@@ -26,6 +26,8 @@ using Tensor = framework::Tensor;
 
 template <typename DeviceContext, typename T>
 class ConcatXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto ins = ctx.MultiInput<framework::LoDTensor>("X");
@@ -79,10 +81,10 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     out->mutable_data<T>(place);
     std::vector<std::vector<int>> xdims_list;
-    std::vector<const T*> ptrs;
+    std::vector<const XPUType*> ptrs;
     for (unsigned int i = 0; i < ins.size(); ++i) {
       if (ins[i] && ins[i]->numel() > 0) {
-        ptrs.push_back(ins[i]->data<T>());
+        ptrs.push_back(reinterpret_cast<const XPUType*>(ins[i]->data<T>()));
         int size = ins[i]->dims().size();
         std::vector<int> tmp_dims(size);
         for (int j = 0; j < size; ++j) {
@@ -96,8 +98,9 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
                                         "No tensor need concat"));
 
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r = xpu::concat<T>(dev_ctx.x_context(), ptrs, out->data<T>(),
-                           xdims_list, axis);
+    int r = xpu::concat<XPUType>(dev_ctx.x_context(), ptrs,
+                                 reinterpret_cast<XPUType*>(out->data<T>()),
+                                 xdims_list, axis);
     PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                       platform::errors::External(
                           "XPU concat kernel return wrong value[%d %s]", r,
@@ -107,6 +110,8 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
 
 template <typename DeviceContext, typename T>
 class ConcatGradXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* out_grad =
@@ -134,12 +139,12 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
     axis = ComputeAxis(static_cast<int64_t>(axis),
                        static_cast<int64_t>(ins[0]->dims().size()));
     // get output tensor that the name is not kEmptyVarName
-    std::vector<T*> ptrs(outs.size());
+    std::vector<XPUType*> ptrs(outs.size());
     for (size_t j = 0; j < outs.size(); ++j) {
       if (out_var_names[j] != framework::kEmptyVarName &&
           outs[j]->numel() != 0UL) {
         outs[j]->mutable_data<T>(ctx.GetPlace());
-        ptrs[j] = outs[j]->data<T>();
+        ptrs[j] = reinterpret_cast<XPUType*>(outs[j]->data<T>());
       } else {
         ptrs[j] = nullptr;
       }
     }
@@ -173,8 +178,10 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
     xdims_list[axis] = total_length;
 
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r = xpu::split<T>(dev_ctx.x_context(), out_grad->data<T>(), ptrs,
-                          xdims_list, split_list, axis);
+    int r = xpu::split<XPUType>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(out_grad->data<T>()), ptrs, xdims_list,
+        split_list, axis);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External(
@@ -189,9 +196,13 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 REGISTER_OP_XPU_KERNEL(
-    concat, ops::ConcatXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    concat, ops::ConcatXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::ConcatXPUKernel<paddle::platform::XPUDeviceContext,
+                         paddle::platform::float16>);
 REGISTER_OP_XPU_KERNEL(
     concat_grad,
-    ops::ConcatGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::ConcatGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::ConcatGradXPUKernel<paddle::platform::XPUDeviceContext,
+                             paddle::platform::float16>);
 
 #endif
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 3a047b8fce..9915b4d8d3 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -56,8 +56,10 @@ XPUOpMap& get_kl2_ops() {
                     pOpKernelType(vartype::INT64, XPUPlace()),
                     pOpKernelType(vartype::INT32, XPUPlace())})},
     {"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                                  pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                             pOpKernelType(vartype::FP16, XPUPlace())})},
     {"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                                   pOpKernelType(vartype::FP16, XPUPlace())})},
     {"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
@@ -288,8 +290,10 @@ XPUOpMap& get_kl2_ops() {
     {"reduce_sum_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-    {"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                                pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                           pOpKernelType(vartype::FP16, XPUPlace())})},
     {"reshape2_grad",
      XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
                    pOpKernelType(vartype::INT64, XPUPlace()),
-- 
GitLab
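
The recurring pattern in this patch is: (1) map the framework element type T to the type the XPU runtime expects via XPUTypeTrait<T>::Type, (2) reinterpret_cast the tensor data pointers to that type so one templated kernel body serves both FP32 and FP16, and (3) register one kernel instantiation per dtype and advertise the FP16 variant in xpu2_op_list.h. The standalone sketch below illustrates only the trait-plus-cast idea; it is not Paddle code, and every name in it (HostFloat16, DeviceFloat16, XPUTypeTraitSketch, ConcatSketch) is a hypothetical stand-in for paddle::platform::float16, the runtime's device-side fp16 type, XPUTypeTrait, and xpu::concat.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical host-side fp16 storage type (stand-in for paddle::platform::float16).
struct HostFloat16 { std::uint16_t x; };

// Hypothetical device-side fp16 type (stand-in for the XPU runtime's fp16 type).
struct DeviceFloat16 { std::uint16_t x; };

// Trait mapping a host element type to the device-side element type:
// FP32 maps to itself, HostFloat16 maps to DeviceFloat16.
template <typename T>
struct XPUTypeTraitSketch { using Type = T; };
template <>
struct XPUTypeTraitSketch<HostFloat16> { using Type = DeviceFloat16; };

// One templated "kernel" body serves both dtypes: cast the host pointers to the
// device-side type once, then run a single code path (a plain copy loop here,
// standing in for the runtime concat call).
template <typename T>
void ConcatSketch(const std::vector<const T*>& ins,
                  const std::vector<std::size_t>& sizes, T* out) {
  using XPUType = typename XPUTypeTraitSketch<T>::Type;
  XPUType* dst = reinterpret_cast<XPUType*>(out);
  for (std::size_t i = 0; i < ins.size(); ++i) {
    const XPUType* src = reinterpret_cast<const XPUType*>(ins[i]);
    for (std::size_t j = 0; j < sizes[i]; ++j) {
      *dst++ = src[j];
    }
  }
}

int main() {
  std::vector<float> a{1.f, 2.f}, b{3.f};
  std::vector<float> out(a.size() + b.size());
  // For float the trait is an identity mapping, so this is an ordinary concat.
  ConcatSketch<float>({a.data(), b.data()}, {a.size(), b.size()}, out.data());
  std::printf("%g %g %g\n", out[0], out[1], out[2]);  // prints: 1 2 3
  return 0;
}

In the patch itself this single-body/per-dtype-registration split shows up as the float and paddle::platform::float16 instantiations passed to REGISTER_OP_XPU_KERNEL, plus the matching vartype::FP16 entries added to the concat/concat_grad and relu/relu_grad sets in xpu2_op_list.h.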