diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index e568de73c0abf2e5ad0d4aeaa2edd510a3dc09b4..b13a74ae72d3a0e0d52ecc08deaeb2b1599d255e 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -42,8 +42,7 @@ if(WITH_XPU) detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_xpu.cc) detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op_xpu.cc) - detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc - generate_proposals_v2_op_xpu.cc) + detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc) elseif(WITH_MLU) detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_mlu.cc) diff --git a/paddle/fluid/operators/detection/generate_proposals_v2_op_xpu.cc b/paddle/fluid/operators/detection/generate_proposals_v2_op_xpu.cc deleted file mode 100644 index 8b513d69c260733c84e9deb5ae5975871cb4c319..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/detection/generate_proposals_v2_op_xpu.cc +++ /dev/null @@ -1,413 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_XPU - -#include -#include - -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/platform/device/device_wrapper.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; - -namespace { -template -static void SortDescending(const platform::XPUDeviceContext &dev_ctx, - const Tensor &value, - Tensor *index_out, - int pre_nms_top_n) { - auto *value_data = value.data(); - auto place = dev_ctx.GetPlace(); - auto cpu_place = platform::CPUPlace(); - - Tensor scores_slice_cpu; - scores_slice_cpu.Resize({value.numel()}); - auto *scores_slice_cpu_data = scores_slice_cpu.mutable_data(cpu_place); - - memory::Copy(cpu_place, - scores_slice_cpu_data, - place, - value_data, - sizeof(T) * value.numel()); - - // Sort index - Tensor index_t; - int *index = index_t.mutable_data({value.numel()}, cpu_place); - for (int i = 0; i < value.numel(); ++i) { - index[i] = i; - } - auto compare = [scores_slice_cpu_data](const int64_t &i, const int64_t &j) { - return scores_slice_cpu_data[i] > scores_slice_cpu_data[j]; - }; - - if (pre_nms_top_n <= 0 || pre_nms_top_n >= value.numel()) { - std::sort(index, index + value.numel(), compare); - } else { - std::nth_element( - index, index + pre_nms_top_n, index + value.numel(), compare); - std::sort(index, index + pre_nms_top_n, compare); - index_t.Resize({pre_nms_top_n}); - } - - int *idx_out = - index_out->mutable_data({index_t.numel()}, dev_ctx.GetPlace()); - memory::Copy(place, idx_out, cpu_place, index, sizeof(T) * index_t.numel()); -} - -template -static std::pair ProposalForOneImage( - const platform::XPUDeviceContext &dev_ctx, - const Tensor &im_shape, - const Tensor &anchors, - const Tensor &variances, - const Tensor &bbox_deltas, // [M, 4] - const Tensor &scores, // [N, 1] - int pre_nms_top_n, - int post_nms_top_n, - float nms_thresh, - float min_size, - float eta, - bool pixel_offset) { - // 1. pre nms - Tensor index_sort; - SortDescending(dev_ctx, scores, &index_sort, pre_nms_top_n); - - Tensor scores_sel, bbox_sel, anchor_sel, var_sel; - scores_sel.mutable_data({index_sort.numel(), 1}, dev_ctx.GetPlace()); - bbox_sel.mutable_data({index_sort.numel(), 4}, dev_ctx.GetPlace()); - anchor_sel.mutable_data({index_sort.numel(), 4}, dev_ctx.GetPlace()); - var_sel.mutable_data({index_sort.numel(), 4}, dev_ctx.GetPlace()); - - int r = xpu::gather(dev_ctx.x_context(), - scores.data(), - index_sort.data(), - scores_sel.data(), - {static_cast(scores.numel()), 1}, - index_sort.numel(), - 0); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather"); - - r = xpu::gather(dev_ctx.x_context(), - bbox_deltas.data(), - index_sort.data(), - bbox_sel.data(), - {static_cast(bbox_deltas.numel()) / 4, 4}, - index_sort.numel(), - 0); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather"); - - r = xpu::gather(dev_ctx.x_context(), - anchors.data(), - index_sort.data(), - anchor_sel.data(), - {static_cast(anchors.numel()) / 4, 4}, - index_sort.numel(), - 0); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather"); - - r = xpu::gather(dev_ctx.x_context(), - variances.data(), - index_sort.data(), - var_sel.data(), - {static_cast(variances.numel()) / 4, 4}, - index_sort.numel(), - 0); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather"); - - int num = scores.numel(); - int pre_nms_num = (pre_nms_top_n <= 0 || pre_nms_top_n > num) ? scores.numel() - : pre_nms_top_n; - scores_sel.Resize({pre_nms_num, 1}); - index_sort.Resize({pre_nms_num, 1}); - - // 2. box decode and clipping - Tensor proposals; - proposals.mutable_data({pre_nms_num, 4}, dev_ctx.GetPlace()); - - r = xpu::box_decoder(dev_ctx.x_context(), - anchor_sel.data(), - var_sel.data(), - bbox_sel.data(), - proposals.data(), - pre_nms_num, - !pixel_offset, - true, - im_shape.data()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "box_decoder"); - - // 3. filter - Tensor keep_index, keep_num_t; - keep_index.mutable_data({pre_nms_num}, dev_ctx.GetPlace()); - keep_num_t.mutable_data({1}, dev_ctx.GetPlace()); - min_size = std::max(min_size, 1.0f); - r = xpu::remove_small_boxes(dev_ctx.x_context(), - proposals.data(), - im_shape.data(), - keep_index.data(), - keep_num_t.data(), - pre_nms_num, - min_size, - false, - pixel_offset); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "remove_small_boxes"); - int keep_num; - const auto xpu_place = dev_ctx.GetPlace(); - memory::Copy(platform::CPUPlace(), - &keep_num, - xpu_place, - keep_num_t.data(), - sizeof(int)); - keep_index.Resize({keep_num}); - - Tensor scores_filter, proposals_filter; - // Handle the case when there is no keep index left - if (keep_num == 0) { - phi::funcs::SetConstant set_zero; - proposals_filter.mutable_data({1, 4}, dev_ctx.GetPlace()); - scores_filter.mutable_data({1, 1}, dev_ctx.GetPlace()); - set_zero(dev_ctx, &proposals_filter, static_cast(0)); - set_zero(dev_ctx, &scores_filter, static_cast(0)); - return std::make_pair(proposals_filter, scores_filter); - } - proposals_filter.mutable_data({keep_num, 4}, dev_ctx.GetPlace()); - scores_filter.mutable_data({keep_num, 1}, dev_ctx.GetPlace()); - r = xpu::gather(dev_ctx.x_context(), - proposals.data(), - keep_index.data(), - proposals_filter.data(), - {pre_nms_num, 4}, - keep_num, - 0); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather"); - - r = xpu::gather(dev_ctx.x_context(), - scores_sel.data(), - keep_index.data(), - scores_filter.data(), - {pre_nms_num, 1}, - keep_num, - 0); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather"); - - if (nms_thresh <= 0) { - if (dev_ctx.x_context()->xpu_stream) { - dev_ctx.Wait(); - } - return std::make_pair(proposals_filter, scores_filter); - } - - // 4. nms - int nms_keep_num = 0; - r = xpu::sorted_nms(dev_ctx.x_context(), - proposals_filter.data(), - keep_index.data(), - nms_keep_num, - keep_num, - nms_thresh, - pixel_offset); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "sorted_nms"); - if (post_nms_top_n > 0 && post_nms_top_n < nms_keep_num) { - keep_index.Resize({post_nms_top_n}); - } else { - keep_index.Resize({nms_keep_num}); - } - - Tensor scores_nms, proposals_nms; - proposals_nms.mutable_data({keep_index.numel(), 4}, dev_ctx.GetPlace()); - scores_nms.mutable_data({keep_index.numel(), 1}, dev_ctx.GetPlace()); - r = xpu::gather(dev_ctx.x_context(), - proposals_filter.data(), - keep_index.data(), - proposals_nms.data(), - {keep_num, 4}, - keep_index.numel(), - 0); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather"); - r = xpu::gather(dev_ctx.x_context(), - scores_filter.data(), - keep_index.data(), - scores_nms.data(), - {keep_num, 1}, - keep_index.numel(), - 0); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather"); - if (dev_ctx.x_context()->xpu_stream) { - dev_ctx.Wait(); - } - return std::make_pair(proposals_nms, scores_nms); -} -} // namespace - -template -class XPUGenerateProposalsV2Kernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - auto *scores = context.Input("Scores"); - auto *bbox_deltas = context.Input("BboxDeltas"); - auto *im_shape = context.Input("ImShape"); - auto anchors = GET_DATA_SAFELY(context.Input("Anchors"), - "Input", - "Anchors", - "GenerateProposals"); - auto variances = GET_DATA_SAFELY(context.Input("Variances"), - "Input", - "Variances", - "GenerateProposals"); - - auto *rpn_rois = context.Output("RpnRois"); - auto *rpn_roi_probs = context.Output("RpnRoiProbs"); - - int pre_nms_top_n = context.Attr("pre_nms_topN"); - int post_nms_top_n = context.Attr("post_nms_topN"); - float nms_thresh = context.Attr("nms_thresh"); - float min_size = context.Attr("min_size"); - float eta = context.Attr("eta"); - bool pixel_offset = context.Attr("pixel_offset"); - PADDLE_ENFORCE_GE(eta, - 1., - platform::errors::InvalidArgument( - "Not support adaptive NMS. The attribute 'eta' " - "should not less than 1. But received eta=[%d]", - eta)); - - auto &dev_ctx = context.template device_context(); - - auto scores_dim = scores->dims(); - // the shape of bbox score - int num = scores_dim[0]; - int c_score = scores_dim[1]; - int h_score = scores_dim[2]; - int w_score = scores_dim[3]; - - auto bbox_dim = bbox_deltas->dims(); - int c_bbox = bbox_dim[1]; - int h_bbox = bbox_dim[2]; - int w_bbox = bbox_dim[3]; - - Tensor bbox_deltas_swap, scores_swap; - bbox_deltas_swap.mutable_data({num, h_bbox, w_bbox, c_bbox}, - dev_ctx.GetPlace()); - scores_swap.mutable_data({num, h_score, w_score, c_score}, - dev_ctx.GetPlace()); - - std::vector axis = {0, 2, 3, 1}; - int r = xpu::transpose(dev_ctx.x_context(), - bbox_deltas->data(), - bbox_deltas_swap.data(), - {num, c_bbox, h_bbox, w_bbox}, - axis); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); - r = xpu::transpose(dev_ctx.x_context(), - scores->data(), - scores_swap.data(), - {num, c_score, h_score, w_score}, - axis); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); - - anchors.Resize({anchors.numel() / 4, 4}); - variances.Resize({variances.numel() / 4, 4}); - - // output - rpn_rois->mutable_data({bbox_deltas->numel() / 4, 4}, - context.GetPlace()); - rpn_roi_probs->mutable_data({scores->numel(), 1}, context.GetPlace()); - - T *rpn_rois_data = rpn_rois->data(); - T *rpn_roi_probs_data = rpn_roi_probs->data(); - - auto place = dev_ctx.GetPlace(); - auto cpu_place = platform::CPUPlace(); - - int num_proposals = 0; - std::vector offset(1, 0); - std::vector tmp_num; - - for (int64_t i = 0; i < num; ++i) { - Tensor im_shape_slice = im_shape->Slice(i, i + 1); - Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1); - Tensor scores_slice = scores_swap.Slice(i, i + 1); - - bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4}); - scores_slice.Resize({h_score * w_score * c_score, 1}); - - std::pair box_score_pair = - ProposalForOneImage(dev_ctx, - im_shape_slice, - anchors, - variances, - bbox_deltas_slice, - scores_slice, - pre_nms_top_n, - post_nms_top_n, - nms_thresh, - min_size, - eta, - pixel_offset); - - Tensor &proposals = box_score_pair.first; - Tensor &scores = box_score_pair.second; - - memory::Copy(place, - rpn_rois_data + num_proposals * 4, - place, - proposals.data(), - sizeof(T) * proposals.numel()); - memory::Copy(place, - rpn_roi_probs_data + num_proposals, - place, - scores.data(), - sizeof(T) * scores.numel()); - if (dev_ctx.x_context()->xpu_stream) { - dev_ctx.Wait(); - } - num_proposals += proposals.dims()[0]; - offset.emplace_back(num_proposals); - tmp_num.push_back(proposals.dims()[0]); - } - if (context.HasOutput("RpnRoisNum")) { - auto *rpn_rois_num = context.Output("RpnRoisNum"); - rpn_rois_num->mutable_data({num}, context.GetPlace()); - int *num_data = rpn_rois_num->data(); - memory::Copy(place, num_data, cpu_place, &tmp_num[0], sizeof(int) * num); - rpn_rois_num->Resize({num}); - } - framework::LoD lod; - lod.emplace_back(offset); - rpn_rois->set_lod(lod); - rpn_roi_probs->set_lod(lod); - rpn_rois->Resize({num_proposals, 4}); - rpn_roi_probs->Resize({num_proposals, 1}); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL( - generate_proposals_v2, - ops::XPUGenerateProposalsV2Kernel); - -#endif // PADDLE_WITH_XPU diff --git a/paddle/fluid/operators/gelu_op_xpu.cc b/paddle/fluid/operators/gelu_op_xpu.cc deleted file mode 100644 index 867830c72e87f1ed4dadbfbc02d04f8b21c84c99..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/gelu_op_xpu.cc +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/tensor.h" -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class GeluXPUKernel : public framework::OpKernel { - using XPUType = typename XPUTypeTrait::Type; - - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - - auto* out = ctx.Output("Out"); - - auto place = ctx.GetPlace(); - - const XPUType* x_data = reinterpret_cast(x->data()); - XPUType* y_data = reinterpret_cast(out->mutable_data(place)); - auto& dev_ctx = ctx.template device_context(); - int r = xpu::gelu(dev_ctx.x_context(), x_data, y_data, x->numel()); - PADDLE_ENFORCE_EQ( - r, - XPU_SUCCESS, - platform::errors::External( - "XPU gelu kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r])); - } -}; - -template -class GeluGradXPUKernel : public framework::OpKernel { - using XPUType = typename XPUTypeTrait::Type; - - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - - auto* dx = ctx.Output(framework::GradVarName("X")); - - auto place = ctx.GetPlace(); - const XPUType* x_data = reinterpret_cast(x->data()); - const XPUType* dout_data = - reinterpret_cast(dout->data()); - XPUType* dx_data = reinterpret_cast(dx->mutable_data(place)); - auto& dev_ctx = ctx.template device_context(); - - int r = xpu::gelu_grad(dev_ctx.x_context(), - x_data, - nullptr, - dout_data, - dx_data, - dout->numel()); - PADDLE_ENFORCE_EQ(r, - XPU_SUCCESS, - platform::errors::External( - "XPU gelu_grad kernel return wrong value[%d %s]", - r, - XPUAPIErrorMsg[r])); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_XPU_KERNEL( - gelu, - ops::GeluXPUKernel, - ops::GeluXPUKernel); - -REGISTER_OP_XPU_KERNEL( - gelu_grad, - ops::GeluGradXPUKernel, - ops::GeluGradXPUKernel); diff --git a/paddle/phi/kernels/funcs/math_function.h b/paddle/phi/kernels/funcs/math_function.h index b93096565603803e5ad8fe192c34093631614947..ee4ac7482f211557045376b2d348dc79e344f8b2 100644 --- a/paddle/phi/kernels/funcs/math_function.h +++ b/paddle/phi/kernels/funcs/math_function.h @@ -53,6 +53,22 @@ struct SetConstant { T num); }; +#ifdef PADDLE_WITH_XPU +template +struct SetConstant { + void operator()(const XPUContext& context, + paddle::framework::Tensor* tensor, + T num); +}; + +template +struct SetConstant { + void operator()(const paddle::platform::XPUDeviceContext& context, + paddle::framework::Tensor* tensor, + T num); +}; +#endif + template void set_constant_with_place(const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, diff --git a/paddle/phi/kernels/funcs/math_function_impl.h b/paddle/phi/kernels/funcs/math_function_impl.h index f9055fb56c913bbff413bcff8ad8c80afb155d45..a6aeeb4f63c0dc2b5cedc279fb69ff26f635f3ce 100644 --- a/paddle/phi/kernels/funcs/math_function_impl.h +++ b/paddle/phi/kernels/funcs/math_function_impl.h @@ -27,20 +27,27 @@ using paddle::framework::To32BitIndex; template void SetConstant::operator()( const DeviceContext& context, paddle::framework::Tensor* tensor, T num) { - bool xpu_place = false; + auto t = paddle::framework::EigenVector::Flatten(*tensor); + t.device(*context.eigen_device()) = t.constant(static_cast(num)); +} + #ifdef PADDLE_WITH_XPU - if (paddle::platform::is_xpu_place(context.GetPlace())) { - xpu_place = true; - phi::VisitDataType( - tensor->dtype(), - TensorSetConstantXPU(tensor, num, context.GetPlace())); - } -#endif - if (!xpu_place) { - auto t = paddle::framework::EigenVector::Flatten(*tensor); - t.device(*context.eigen_device()) = t.constant(static_cast(num)); - } +template +void SetConstant::operator()(const XPUContext& context, + paddle::framework::Tensor* tensor, + T num) { + phi::VisitDataType(tensor->dtype(), + TensorSetConstantXPU(tensor, num, context.GetPlace())); +} +template +void SetConstant::operator()( + const paddle::platform::XPUDeviceContext& context, + paddle::framework::Tensor* tensor, + T num) { + phi::VisitDataType(tensor->dtype(), + TensorSetConstantXPU(tensor, num, context.GetPlace())); } +#endif template void Transpose::operator()( diff --git a/paddle/phi/kernels/xpu/gelu_grad_kernel.cc b/paddle/phi/kernels/xpu/gelu_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..9a99c93ccb58c89fee4f724ab728fd84105ca398 --- /dev/null +++ b/paddle/phi/kernels/xpu/gelu_grad_kernel.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gelu_grad_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void GeluGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + bool approximate, + DenseTensor* x_grad) { + using XPUType = typename XPUTypeTrait::Type; + + dev_ctx.template Alloc(x_grad); + int r = xpu::gelu_grad( + dev_ctx.x_context(), + reinterpret_cast(x.data()), + nullptr, + reinterpret_cast(out_grad.data()), + reinterpret_cast(x_grad->data()), + x_grad->numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu_grad"); +} +} // namespace phi + +PD_REGISTER_KERNEL(gelu_grad, + XPU, + ALL_LAYOUT, + phi::GeluGradKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/gelu_kernel.cc b/paddle/phi/kernels/xpu/gelu_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..d69d4ab8dd1415c18d39b816b92c2c6bc686b10e --- /dev/null +++ b/paddle/phi/kernels/xpu/gelu_kernel.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gelu_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void GeluKernel(const Context& dev_ctx, + const DenseTensor& x, + bool approximate, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + dev_ctx.template Alloc(out); + int r = xpu::gelu(dev_ctx.x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(out->data()), + out->numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu"); +} +} // namespace phi + +PD_REGISTER_KERNEL( + gelu, XPU, ALL_LAYOUT, phi::GeluKernel, float, phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/generate_proposals_v2_kernel.cc b/paddle/phi/kernels/xpu/generate_proposals_v2_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..5a91f5ad9d52f9f34b98e77d530203d2e247b569 --- /dev/null +++ b/paddle/phi/kernels/xpu/generate_proposals_v2_kernel.cc @@ -0,0 +1,415 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/generate_proposals_v2_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function_impl.h" + +#include "paddle/fluid/memory/memcpy.h" + +namespace phi { + +template +static void SortDescending(const XPUContext& dev_ctx, + const DenseTensor& value, + DenseTensor* index_out, + int pre_nms_top_n) { + auto* value_data = value.data(); + auto place = dev_ctx.GetPlace(); + auto cpu_place = phi::CPUPlace(); + + DenseTensor scores_slice_cpu; + scores_slice_cpu.Resize({value.numel()}); + T* scores_slice_cpu_data = dev_ctx.template HostAlloc(&scores_slice_cpu); + + paddle::memory::Copy(cpu_place, + scores_slice_cpu_data, + place, + value_data, + sizeof(T) * value.numel()); + // Sort index + DenseTensor index_t; + index_t.Resize({value.numel()}); + int* index = dev_ctx.template HostAlloc(&index_t); + for (int i = 0; i < value.numel(); ++i) { + index[i] = i; + } + + auto compare = [scores_slice_cpu_data](const int64_t& i, const int64_t& j) { + return scores_slice_cpu_data[i] > scores_slice_cpu_data[j]; + }; + + if (pre_nms_top_n <= 0 || pre_nms_top_n >= value.numel()) { + std::sort(index, index + value.numel(), compare); + } else { + std::nth_element( + index, index + pre_nms_top_n, index + value.numel(), compare); + std::sort(index, index + pre_nms_top_n, compare); + index_t.Resize({pre_nms_top_n}); + } + + index_out->Resize({index_t.numel()}); + int* idx_out = dev_ctx.template Alloc(index_out); + paddle::memory::Copy( + place, idx_out, cpu_place, index, sizeof(T) * index_t.numel()); +} + +template +std::pair ProposalForOneImage( + const phi::XPUContext& dev_ctx, + const DenseTensor& im_shape_slice, + const DenseTensor& anchors, + const DenseTensor& variances, + const DenseTensor& bbox_deltas_slice, // [M, 4] + const DenseTensor& scores_slice, // [N, 1] + int pre_nms_top_n, + int post_nms_top_n, + float nms_thresh, + float min_size, + float eta, + bool pixel_offset = true) { + // 1. pre nms + DenseTensor index_sort; + SortDescending(dev_ctx, scores_slice, &index_sort, pre_nms_top_n); + + DenseTensor scores_sel, bbox_sel, anchor_sel, var_sel; + scores_sel.Resize(phi::make_ddim({index_sort.numel(), 1})); + dev_ctx.template Alloc(&scores_sel); + + bbox_sel.Resize(phi::make_ddim({index_sort.numel(), 4})); + dev_ctx.template Alloc(&bbox_sel); + + anchor_sel.Resize(phi::make_ddim({index_sort.numel(), 4})); + dev_ctx.template Alloc(&anchor_sel); + + var_sel.Resize(phi::make_ddim({index_sort.numel(), 4})); + dev_ctx.template Alloc(&var_sel); + + int r = xpu::gather(dev_ctx.x_context(), + scores_slice.data(), + index_sort.data(), + scores_sel.data(), + {static_cast(scores_slice.numel()), 1}, + index_sort.numel(), + 0); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather"); + + r = xpu::gather(dev_ctx.x_context(), + bbox_deltas_slice.data(), + index_sort.data(), + bbox_sel.data(), + {static_cast(bbox_deltas_slice.numel()) / 4, 4}, + index_sort.numel(), + 0); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather"); + + r = xpu::gather(dev_ctx.x_context(), + anchors.data(), + index_sort.data(), + anchor_sel.data(), + {static_cast(anchors.numel()) / 4, 4}, + index_sort.numel(), + 0); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather"); + + r = xpu::gather(dev_ctx.x_context(), + variances.data(), + index_sort.data(), + var_sel.data(), + {static_cast(variances.numel()) / 4, 4}, + index_sort.numel(), + 0); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather"); + + int num = scores_slice.numel(); + int pre_nms_num = (pre_nms_top_n <= 0 || pre_nms_top_n > num) + ? scores_slice.numel() + : pre_nms_top_n; + scores_sel.Resize({pre_nms_num, 1}); + index_sort.Resize({pre_nms_num, 1}); + + // 2. box decode and clipping + DenseTensor proposals; + proposals.Resize(phi::make_ddim({index_sort.numel(), 4})); + dev_ctx.template Alloc(&proposals); + + r = xpu::box_decoder(dev_ctx.x_context(), + anchor_sel.data(), + var_sel.data(), + bbox_sel.data(), + proposals.data(), + pre_nms_num, + !pixel_offset, + true, + im_shape_slice.data()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "box_decoder"); + + // 3. filter + DenseTensor keep_index, keep_num_t; + keep_index.Resize(phi::make_ddim({pre_nms_num})); + dev_ctx.template Alloc(&keep_index); + + keep_num_t.Resize(phi::make_ddim({1})); + dev_ctx.template Alloc(&keep_num_t); + min_size = std::max(min_size, 1.0f); + r = xpu::remove_small_boxes(dev_ctx.x_context(), + proposals.data(), + im_shape_slice.data(), + keep_index.data(), + keep_num_t.data(), + pre_nms_num, + min_size, + false, + pixel_offset); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "remove_small_boxes"); + + int keep_num; + const auto xpu_place = dev_ctx.GetPlace(); + paddle::memory::Copy(phi::CPUPlace(), + &keep_num, + xpu_place, + keep_num_t.data(), + sizeof(int)); + keep_index.Resize({keep_num}); + + DenseTensor scores_filter, proposals_filter; + // Handle the case when there is no keep index left + if (keep_num == 0) { + phi::funcs::SetConstant set_zero; + proposals_filter.Resize(phi::make_ddim({1, 4})); + dev_ctx.template Alloc(&proposals_filter); + scores_filter.Resize(phi::make_ddim({1, 1})); + dev_ctx.template Alloc(&scores_filter); + set_zero(dev_ctx, &proposals_filter, static_cast(0)); + set_zero(dev_ctx, &scores_filter, static_cast(0)); + return std::make_pair(proposals_filter, scores_filter); + } + proposals_filter.Resize(phi::make_ddim({keep_num, 4})); + dev_ctx.template Alloc(&proposals_filter); + scores_filter.Resize(phi::make_ddim({keep_num, 1})); + dev_ctx.template Alloc(&scores_filter); + r = xpu::gather(dev_ctx.x_context(), + proposals.data(), + keep_index.data(), + proposals_filter.data(), + {pre_nms_num, 4}, + keep_num, + 0); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather"); + + r = xpu::gather(dev_ctx.x_context(), + scores_sel.data(), + keep_index.data(), + scores_filter.data(), + {pre_nms_num, 1}, + keep_num, + 0); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather"); + + if (nms_thresh <= 0) { + if (dev_ctx.x_context()->xpu_stream) { + dev_ctx.Wait(); + } + return std::make_pair(proposals_filter, scores_filter); + } + + // 4. nms + int nms_keep_num = 0; + r = xpu::sorted_nms(dev_ctx.x_context(), + proposals_filter.data(), + keep_index.data(), + nms_keep_num, + keep_num, + nms_thresh, + pixel_offset); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "sorted_nms"); + if (post_nms_top_n > 0 && post_nms_top_n < nms_keep_num) { + keep_index.Resize({post_nms_top_n}); + } else { + keep_index.Resize({nms_keep_num}); + } + + DenseTensor scores_nms, proposals_nms; + proposals_nms.Resize(phi::make_ddim({keep_index.numel(), 4})); + dev_ctx.template Alloc(&proposals_nms); + scores_nms.Resize(phi::make_ddim({keep_index.numel(), 1})); + dev_ctx.template Alloc(&scores_nms); + r = xpu::gather(dev_ctx.x_context(), + proposals_filter.data(), + keep_index.data(), + proposals_nms.data(), + {keep_num, 4}, + keep_index.numel(), + 0); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather"); + r = xpu::gather(dev_ctx.x_context(), + scores_filter.data(), + keep_index.data(), + scores_nms.data(), + {keep_num, 1}, + keep_index.numel(), + 0); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather"); + if (dev_ctx.x_context()->xpu_stream) { + dev_ctx.Wait(); + } + return std::make_pair(proposals_nms, scores_nms); +} + +template +void GenerateProposalsV2Kernel(const Context& dev_ctx, + const DenseTensor& scores, + const DenseTensor& bbox_deltas, + const DenseTensor& im_shape, + const DenseTensor& anchors, + const DenseTensor& variances, + int pre_nms_top_n, + int post_nms_top_n, + float nms_thresh, + float min_size, + float eta, + bool pixel_offset, + DenseTensor* rpn_rois, + DenseTensor* rpn_roi_probs, + DenseTensor* rpn_rois_num) { + PADDLE_ENFORCE_GE(eta, + 1., + phi::errors::InvalidArgument( + "Not support adaptive NMS. The attribute 'eta' " + "should not less than 1. But received eta=[%d]", + eta)); + + auto& scores_dim = scores.dims(); + // the shape of bbox score + int num = scores_dim[0]; + int c_score = scores_dim[1]; + int h_score = scores_dim[2]; + int w_score = scores_dim[3]; + + auto& bbox_dim = bbox_deltas.dims(); + int c_bbox = bbox_dim[1]; + int h_bbox = bbox_dim[2]; + int w_bbox = bbox_dim[3]; + + DenseTensor bbox_deltas_swap, scores_swap; + bbox_deltas_swap.Resize(phi::make_ddim({num, h_bbox, w_bbox, c_bbox})); + dev_ctx.template Alloc(&bbox_deltas_swap); + + scores_swap.Resize(phi::make_ddim({num, h_score, w_score, c_score})); + dev_ctx.template Alloc(&scores_swap); + + std::vector axis = {0, 2, 3, 1}; + int r = xpu::transpose(dev_ctx.x_context(), + bbox_deltas.data(), + bbox_deltas_swap.data(), + {num, c_bbox, h_bbox, w_bbox}, + axis); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + + r = xpu::transpose(dev_ctx.x_context(), + scores.data(), + scores_swap.data(), + {num, c_score, h_score, w_score}, + axis); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + + DenseTensor tmp_anchors = anchors; + DenseTensor tmp_variances = variances; + tmp_anchors.Resize(phi::make_ddim({tmp_anchors.numel() / 4, 4})); + tmp_variances.Resize(phi::make_ddim({tmp_variances.numel() / 4, 4})); + + // output + rpn_rois->Resize(phi::make_ddim({bbox_deltas.numel() / 4, 4})); + dev_ctx.template Alloc(rpn_rois); + + rpn_roi_probs->Resize(phi::make_ddim({scores.numel(), 1})); + dev_ctx.template Alloc(rpn_roi_probs); + + auto place = dev_ctx.GetPlace(); + auto cpu_place = phi::CPUPlace(); + + int num_proposals = 0; + std::vector offset(1, 0); + std::vector tmp_num; + + for (int64_t i = 0; i < num; ++i) { + DenseTensor im_shape_slice = im_shape.Slice(i, i + 1); + DenseTensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1); + DenseTensor scores_slice = scores_swap.Slice(i, i + 1); + + bbox_deltas_slice.Resize(phi::make_ddim({h_bbox * w_bbox * c_bbox / 4, 4})); + scores_slice.Resize(phi::make_ddim({h_score * w_score * c_score, 1})); + + std::pair tensor_pair = + ProposalForOneImage(dev_ctx, + im_shape_slice, + tmp_anchors, + tmp_variances, + bbox_deltas_slice, + scores_slice, + pre_nms_top_n, + post_nms_top_n, + nms_thresh, + min_size, + eta, + pixel_offset); + + DenseTensor& proposals = tensor_pair.first; + DenseTensor& nscores = tensor_pair.second; + + paddle::memory::Copy(place, + rpn_rois->data() + num_proposals * 4, + place, + proposals.data(), + sizeof(T) * proposals.numel()); + paddle::memory::Copy(place, + rpn_roi_probs->data() + num_proposals, + place, + nscores.data(), + sizeof(T) * scores.numel()); + + if (dev_ctx.x_context()->xpu_stream) { + dev_ctx.Wait(); + } + num_proposals += proposals.dims()[0]; + offset.emplace_back(num_proposals); + tmp_num.push_back(proposals.dims()[0]); + } + + if (rpn_rois_num != nullptr) { + rpn_rois_num->Resize(phi::make_ddim({num})); + dev_ctx.template Alloc(rpn_rois_num); + int* num_data = rpn_rois_num->data(); + paddle::memory::Copy( + place, num_data, cpu_place, &tmp_num[0], sizeof(int) * num); + } + + phi::LoD lod; + lod.emplace_back(offset); + rpn_rois->set_lod(lod); + rpn_roi_probs->set_lod(lod); + rpn_rois->Resize(phi::make_ddim({num_proposals, 4})); + rpn_roi_probs->Resize(phi::make_ddim({num_proposals, 1})); +} +} // namespace phi + +PD_REGISTER_KERNEL(generate_proposals_v2, + XPU, + ALL_LAYOUT, + phi::GenerateProposalsV2Kernel, + float) {}