未验证 提交 8b24c795 编写于 作者: L Leo Chen 提交者: GitHub

move gelu/gelu_grad/generate_proposals_v2 kernel to phi (#45471)

* move xpu kernel to phi

* delete fluid file

* fix compile

* add guard, test=kunlun

* xpu set constant

* fix xpu error, test=kunlun
上级 c857841e
......@@ -42,8 +42,7 @@ if(WITH_XPU)
detection_library(iou_similarity_op SRCS iou_similarity_op.cc
iou_similarity_op_xpu.cc)
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op_xpu.cc)
detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc
generate_proposals_v2_op_xpu.cc)
detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc)
elseif(WITH_MLU)
detection_library(iou_similarity_op SRCS iou_similarity_op.cc
iou_similarity_op_mlu.cc)
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class GeluXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
auto place = ctx.GetPlace();
const XPUType* x_data = reinterpret_cast<const XPUType*>(x->data<T>());
XPUType* y_data = reinterpret_cast<XPUType*>(out->mutable_data<T>(place));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::gelu<XPUType>(dev_ctx.x_context(), x_data, y_data, x->numel());
PADDLE_ENFORCE_EQ(
r,
XPU_SUCCESS,
platform::errors::External(
"XPU gelu kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
}
};
template <typename DeviceContext, typename T>
class GeluGradXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto place = ctx.GetPlace();
const XPUType* x_data = reinterpret_cast<const XPUType*>(x->data<T>());
const XPUType* dout_data =
reinterpret_cast<const XPUType*>(dout->data<T>());
XPUType* dx_data = reinterpret_cast<XPUType*>(dx->mutable_data<T>(place));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::gelu_grad<XPUType>(dev_ctx.x_context(),
x_data,
nullptr,
dout_data,
dx_data,
dout->numel());
PADDLE_ENFORCE_EQ(r,
XPU_SUCCESS,
platform::errors::External(
"XPU gelu_grad kernel return wrong value[%d %s]",
r,
XPUAPIErrorMsg[r]));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
gelu,
ops::GeluXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GeluXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
gelu_grad,
ops::GeluGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GeluGradXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
......@@ -53,6 +53,22 @@ struct SetConstant {
T num);
};
#ifdef PADDLE_WITH_XPU
template <typename T>
struct SetConstant<XPUContext, T> {
void operator()(const XPUContext& context,
paddle::framework::Tensor* tensor,
T num);
};
template <typename T>
struct SetConstant<paddle::platform::XPUDeviceContext, T> {
void operator()(const paddle::platform::XPUDeviceContext& context,
paddle::framework::Tensor* tensor,
T num);
};
#endif
template <typename Place>
void set_constant_with_place(const paddle::platform::DeviceContext& context,
paddle::framework::Tensor* tensor,
......
......@@ -27,21 +27,28 @@ using paddle::framework::To32BitIndex;
template <typename DeviceContext, typename T>
void SetConstant<DeviceContext, T>::operator()(
const DeviceContext& context, paddle::framework::Tensor* tensor, T num) {
bool xpu_place = false;
#ifdef PADDLE_WITH_XPU
if (paddle::platform::is_xpu_place(context.GetPlace())) {
xpu_place = true;
phi::VisitDataType(
tensor->dtype(),
TensorSetConstantXPU<T>(tensor, num, context.GetPlace()));
}
#endif
if (!xpu_place) {
auto t = paddle::framework::EigenVector<T>::Flatten(*tensor);
t.device(*context.eigen_device()) = t.constant(static_cast<T>(num));
}
}
#ifdef PADDLE_WITH_XPU
template <typename T>
void SetConstant<XPUContext, T>::operator()(const XPUContext& context,
paddle::framework::Tensor* tensor,
T num) {
phi::VisitDataType(tensor->dtype(),
TensorSetConstantXPU<T>(tensor, num, context.GetPlace()));
}
template <typename T>
void SetConstant<paddle::platform::XPUDeviceContext, T>::operator()(
const paddle::platform::XPUDeviceContext& context,
paddle::framework::Tensor* tensor,
T num) {
phi::VisitDataType(tensor->dtype(),
TensorSetConstantXPU<T>(tensor, num, context.GetPlace()));
}
#endif
template <typename DeviceContext, typename T, int Rank>
void Transpose<DeviceContext, T, Rank>::operator()(
const DeviceContext& context,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gelu_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void GeluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
bool approximate,
DenseTensor* x_grad) {
using XPUType = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(x_grad);
int r = xpu::gelu_grad<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
nullptr,
reinterpret_cast<const XPUType*>(out_grad.data<T>()),
reinterpret_cast<XPUType*>(x_grad->data<T>()),
x_grad->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu_grad");
}
} // namespace phi
PD_REGISTER_KERNEL(gelu_grad,
XPU,
ALL_LAYOUT,
phi::GeluGradKernel,
float,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gelu_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void GeluKernel(const Context& dev_ctx,
const DenseTensor& x,
bool approximate,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(out);
int r = xpu::gelu<XPUType>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu");
}
} // namespace phi
PD_REGISTER_KERNEL(
gelu, XPU, ALL_LAYOUT, phi::GeluKernel, float, phi::dtype::float16) {}
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/generate_proposals_v2_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function_impl.h"
#include "paddle/fluid/memory/memcpy.h"
namespace phi {
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <paddle/fluid/memory/allocation/allocator.h>
#include <stdio.h>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
namespace {
template <typename T>
static void SortDescending(const platform::XPUDeviceContext &dev_ctx,
const Tensor &value,
Tensor *index_out,
static void SortDescending(const XPUContext& dev_ctx,
const DenseTensor& value,
DenseTensor* index_out,
int pre_nms_top_n) {
auto *value_data = value.data<T>();
auto* value_data = value.data<T>();
auto place = dev_ctx.GetPlace();
auto cpu_place = platform::CPUPlace();
auto cpu_place = phi::CPUPlace();
Tensor scores_slice_cpu;
DenseTensor scores_slice_cpu;
scores_slice_cpu.Resize({value.numel()});
auto *scores_slice_cpu_data = scores_slice_cpu.mutable_data<T>(cpu_place);
T* scores_slice_cpu_data = dev_ctx.template HostAlloc<T>(&scores_slice_cpu);
memory::Copy(cpu_place,
paddle::memory::Copy(cpu_place,
scores_slice_cpu_data,
place,
value_data,
sizeof(T) * value.numel());
// Sort index
Tensor index_t;
int *index = index_t.mutable_data<int>({value.numel()}, cpu_place);
DenseTensor index_t;
index_t.Resize({value.numel()});
int* index = dev_ctx.template HostAlloc<int>(&index_t);
for (int i = 0; i < value.numel(); ++i) {
index[i] = i;
}
auto compare = [scores_slice_cpu_data](const int64_t &i, const int64_t &j) {
auto compare = [scores_slice_cpu_data](const int64_t& i, const int64_t& j) {
return scores_slice_cpu_data[i] > scores_slice_cpu_data[j];
};
......@@ -70,49 +63,57 @@ static void SortDescending(const platform::XPUDeviceContext &dev_ctx,
index_t.Resize({pre_nms_top_n});
}
int *idx_out =
index_out->mutable_data<int>({index_t.numel()}, dev_ctx.GetPlace());
memory::Copy(place, idx_out, cpu_place, index, sizeof(T) * index_t.numel());
index_out->Resize({index_t.numel()});
int* idx_out = dev_ctx.template Alloc<int>(index_out);
paddle::memory::Copy(
place, idx_out, cpu_place, index, sizeof(T) * index_t.numel());
}
template <typename T>
static std::pair<Tensor, Tensor> ProposalForOneImage(
const platform::XPUDeviceContext &dev_ctx,
const Tensor &im_shape,
const Tensor &anchors,
const Tensor &variances,
const Tensor &bbox_deltas, // [M, 4]
const Tensor &scores, // [N, 1]
std::pair<DenseTensor, DenseTensor> ProposalForOneImage(
const phi::XPUContext& dev_ctx,
const DenseTensor& im_shape_slice,
const DenseTensor& anchors,
const DenseTensor& variances,
const DenseTensor& bbox_deltas_slice, // [M, 4]
const DenseTensor& scores_slice, // [N, 1]
int pre_nms_top_n,
int post_nms_top_n,
float nms_thresh,
float min_size,
float eta,
bool pixel_offset) {
bool pixel_offset = true) {
// 1. pre nms
Tensor index_sort;
SortDescending<T>(dev_ctx, scores, &index_sort, pre_nms_top_n);
DenseTensor index_sort;
SortDescending<T>(dev_ctx, scores_slice, &index_sort, pre_nms_top_n);
DenseTensor scores_sel, bbox_sel, anchor_sel, var_sel;
scores_sel.Resize(phi::make_ddim({index_sort.numel(), 1}));
dev_ctx.template Alloc<T>(&scores_sel);
bbox_sel.Resize(phi::make_ddim({index_sort.numel(), 4}));
dev_ctx.template Alloc<T>(&bbox_sel);
Tensor scores_sel, bbox_sel, anchor_sel, var_sel;
scores_sel.mutable_data<T>({index_sort.numel(), 1}, dev_ctx.GetPlace());
bbox_sel.mutable_data<T>({index_sort.numel(), 4}, dev_ctx.GetPlace());
anchor_sel.mutable_data<T>({index_sort.numel(), 4}, dev_ctx.GetPlace());
var_sel.mutable_data<T>({index_sort.numel(), 4}, dev_ctx.GetPlace());
anchor_sel.Resize(phi::make_ddim({index_sort.numel(), 4}));
dev_ctx.template Alloc<T>(&anchor_sel);
var_sel.Resize(phi::make_ddim({index_sort.numel(), 4}));
dev_ctx.template Alloc<T>(&var_sel);
int r = xpu::gather<T>(dev_ctx.x_context(),
scores.data<T>(),
scores_slice.data<T>(),
index_sort.data<int>(),
scores_sel.data<T>(),
{static_cast<int>(scores.numel()), 1},
{static_cast<int>(scores_slice.numel()), 1},
index_sort.numel(),
0);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather");
r = xpu::gather<T>(dev_ctx.x_context(),
bbox_deltas.data<T>(),
bbox_deltas_slice.data<T>(),
index_sort.data<int>(),
bbox_sel.data<T>(),
{static_cast<int>(bbox_deltas.numel()) / 4, 4},
{static_cast<int>(bbox_deltas_slice.numel()) / 4, 4},
index_sort.numel(),
0);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather");
......@@ -135,15 +136,17 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
0);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather");
int num = scores.numel();
int pre_nms_num = (pre_nms_top_n <= 0 || pre_nms_top_n > num) ? scores.numel()
int num = scores_slice.numel();
int pre_nms_num = (pre_nms_top_n <= 0 || pre_nms_top_n > num)
? scores_slice.numel()
: pre_nms_top_n;
scores_sel.Resize({pre_nms_num, 1});
index_sort.Resize({pre_nms_num, 1});
// 2. box decode and clipping
Tensor proposals;
proposals.mutable_data<T>({pre_nms_num, 4}, dev_ctx.GetPlace());
DenseTensor proposals;
proposals.Resize(phi::make_ddim({index_sort.numel(), 4}));
dev_ctx.template Alloc<T>(&proposals);
r = xpu::box_decoder<T>(dev_ctx.x_context(),
anchor_sel.data<T>(),
......@@ -153,17 +156,20 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
pre_nms_num,
!pixel_offset,
true,
im_shape.data<T>());
im_shape_slice.data<T>());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "box_decoder");
// 3. filter
Tensor keep_index, keep_num_t;
keep_index.mutable_data<int>({pre_nms_num}, dev_ctx.GetPlace());
keep_num_t.mutable_data<int>({1}, dev_ctx.GetPlace());
DenseTensor keep_index, keep_num_t;
keep_index.Resize(phi::make_ddim({pre_nms_num}));
dev_ctx.template Alloc<int>(&keep_index);
keep_num_t.Resize(phi::make_ddim({1}));
dev_ctx.template Alloc<int>(&keep_num_t);
min_size = std::max(min_size, 1.0f);
r = xpu::remove_small_boxes<T>(dev_ctx.x_context(),
proposals.data<T>(),
im_shape.data<T>(),
im_shape_slice.data<T>(),
keep_index.data<int>(),
keep_num_t.data<int>(),
pre_nms_num,
......@@ -171,27 +177,32 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
false,
pixel_offset);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "remove_small_boxes");
int keep_num;
const auto xpu_place = dev_ctx.GetPlace();
memory::Copy(platform::CPUPlace(),
paddle::memory::Copy(phi::CPUPlace(),
&keep_num,
xpu_place,
keep_num_t.data<int>(),
sizeof(int));
keep_index.Resize({keep_num});
Tensor scores_filter, proposals_filter;
DenseTensor scores_filter, proposals_filter;
// Handle the case when there is no keep index left
if (keep_num == 0) {
phi::funcs::SetConstant<platform::XPUDeviceContext, T> set_zero;
proposals_filter.mutable_data<T>({1, 4}, dev_ctx.GetPlace());
scores_filter.mutable_data<T>({1, 1}, dev_ctx.GetPlace());
phi::funcs::SetConstant<phi::XPUContext, T> set_zero;
proposals_filter.Resize(phi::make_ddim({1, 4}));
dev_ctx.template Alloc<T>(&proposals_filter);
scores_filter.Resize(phi::make_ddim({1, 1}));
dev_ctx.template Alloc<T>(&scores_filter);
set_zero(dev_ctx, &proposals_filter, static_cast<T>(0));
set_zero(dev_ctx, &scores_filter, static_cast<T>(0));
return std::make_pair(proposals_filter, scores_filter);
}
proposals_filter.mutable_data<T>({keep_num, 4}, dev_ctx.GetPlace());
scores_filter.mutable_data<T>({keep_num, 1}, dev_ctx.GetPlace());
proposals_filter.Resize(phi::make_ddim({keep_num, 4}));
dev_ctx.template Alloc<T>(&proposals_filter);
scores_filter.Resize(phi::make_ddim({keep_num, 1}));
dev_ctx.template Alloc<T>(&scores_filter);
r = xpu::gather<T>(dev_ctx.x_context(),
proposals.data<T>(),
keep_index.data<int>(),
......@@ -233,9 +244,11 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
keep_index.Resize({nms_keep_num});
}
Tensor scores_nms, proposals_nms;
proposals_nms.mutable_data<T>({keep_index.numel(), 4}, dev_ctx.GetPlace());
scores_nms.mutable_data<T>({keep_index.numel(), 1}, dev_ctx.GetPlace());
DenseTensor scores_nms, proposals_nms;
proposals_nms.Resize(phi::make_ddim({keep_index.numel(), 4}));
dev_ctx.template Alloc<T>(&proposals_nms);
scores_nms.Resize(phi::make_ddim({keep_index.numel(), 1}));
dev_ctx.template Alloc<T>(&scores_nms);
r = xpu::gather<T>(dev_ctx.x_context(),
proposals_filter.data<T>(),
keep_index.data<int>(),
......@@ -257,105 +270,96 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
}
return std::make_pair(proposals_nms, scores_nms);
}
} // namespace
template <typename DeviceContext, typename T>
class XPUGenerateProposalsV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_shape = context.Input<Tensor>("ImShape");
auto anchors = GET_DATA_SAFELY(context.Input<Tensor>("Anchors"),
"Input",
"Anchors",
"GenerateProposals");
auto variances = GET_DATA_SAFELY(context.Input<Tensor>("Variances"),
"Input",
"Variances",
"GenerateProposals");
auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
int pre_nms_top_n = context.Attr<int>("pre_nms_topN");
int post_nms_top_n = context.Attr<int>("post_nms_topN");
float nms_thresh = context.Attr<float>("nms_thresh");
float min_size = context.Attr<float>("min_size");
float eta = context.Attr<float>("eta");
bool pixel_offset = context.Attr<bool>("pixel_offset");
template <typename T, typename Context>
void GenerateProposalsV2Kernel(const Context& dev_ctx,
const DenseTensor& scores,
const DenseTensor& bbox_deltas,
const DenseTensor& im_shape,
const DenseTensor& anchors,
const DenseTensor& variances,
int pre_nms_top_n,
int post_nms_top_n,
float nms_thresh,
float min_size,
float eta,
bool pixel_offset,
DenseTensor* rpn_rois,
DenseTensor* rpn_roi_probs,
DenseTensor* rpn_rois_num) {
PADDLE_ENFORCE_GE(eta,
1.,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"Not support adaptive NMS. The attribute 'eta' "
"should not less than 1. But received eta=[%d]",
eta));
auto &dev_ctx = context.template device_context<DeviceContext>();
auto scores_dim = scores->dims();
auto& scores_dim = scores.dims();
// the shape of bbox score
int num = scores_dim[0];
int c_score = scores_dim[1];
int h_score = scores_dim[2];
int w_score = scores_dim[3];
auto bbox_dim = bbox_deltas->dims();
auto& bbox_dim = bbox_deltas.dims();
int c_bbox = bbox_dim[1];
int h_bbox = bbox_dim[2];
int w_bbox = bbox_dim[3];
Tensor bbox_deltas_swap, scores_swap;
bbox_deltas_swap.mutable_data<T>({num, h_bbox, w_bbox, c_bbox},
dev_ctx.GetPlace());
scores_swap.mutable_data<T>({num, h_score, w_score, c_score},
dev_ctx.GetPlace());
DenseTensor bbox_deltas_swap, scores_swap;
bbox_deltas_swap.Resize(phi::make_ddim({num, h_bbox, w_bbox, c_bbox}));
dev_ctx.template Alloc<T>(&bbox_deltas_swap);
scores_swap.Resize(phi::make_ddim({num, h_score, w_score, c_score}));
dev_ctx.template Alloc<T>(&scores_swap);
std::vector<int> axis = {0, 2, 3, 1};
int r = xpu::transpose<T>(dev_ctx.x_context(),
bbox_deltas->data<T>(),
bbox_deltas.data<T>(),
bbox_deltas_swap.data<T>(),
{num, c_bbox, h_bbox, w_bbox},
axis);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
r = xpu::transpose<T>(dev_ctx.x_context(),
scores->data<T>(),
scores.data<T>(),
scores_swap.data<T>(),
{num, c_score, h_score, w_score},
axis);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
anchors.Resize({anchors.numel() / 4, 4});
variances.Resize({variances.numel() / 4, 4});
DenseTensor tmp_anchors = anchors;
DenseTensor tmp_variances = variances;
tmp_anchors.Resize(phi::make_ddim({tmp_anchors.numel() / 4, 4}));
tmp_variances.Resize(phi::make_ddim({tmp_variances.numel() / 4, 4}));
// output
rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
context.GetPlace());
rpn_roi_probs->mutable_data<T>({scores->numel(), 1}, context.GetPlace());
rpn_rois->Resize(phi::make_ddim({bbox_deltas.numel() / 4, 4}));
dev_ctx.template Alloc<T>(rpn_rois);
T *rpn_rois_data = rpn_rois->data<T>();
T *rpn_roi_probs_data = rpn_roi_probs->data<T>();
rpn_roi_probs->Resize(phi::make_ddim({scores.numel(), 1}));
dev_ctx.template Alloc<T>(rpn_roi_probs);
auto place = dev_ctx.GetPlace();
auto cpu_place = platform::CPUPlace();
auto cpu_place = phi::CPUPlace();
int num_proposals = 0;
std::vector<size_t> offset(1, 0);
std::vector<int> tmp_num;
for (int64_t i = 0; i < num; ++i) {
Tensor im_shape_slice = im_shape->Slice(i, i + 1);
Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1);
Tensor scores_slice = scores_swap.Slice(i, i + 1);
DenseTensor im_shape_slice = im_shape.Slice(i, i + 1);
DenseTensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1);
DenseTensor scores_slice = scores_swap.Slice(i, i + 1);
bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4});
scores_slice.Resize({h_score * w_score * c_score, 1});
bbox_deltas_slice.Resize(phi::make_ddim({h_bbox * w_bbox * c_bbox / 4, 4}));
scores_slice.Resize(phi::make_ddim({h_score * w_score * c_score, 1}));
std::pair<Tensor, Tensor> box_score_pair =
std::pair<DenseTensor, DenseTensor> tensor_pair =
ProposalForOneImage<T>(dev_ctx,
im_shape_slice,
anchors,
variances,
tmp_anchors,
tmp_variances,
bbox_deltas_slice,
scores_slice,
pre_nms_top_n,
......@@ -365,19 +369,20 @@ class XPUGenerateProposalsV2Kernel : public framework::OpKernel<T> {
eta,
pixel_offset);
Tensor &proposals = box_score_pair.first;
Tensor &scores = box_score_pair.second;
DenseTensor& proposals = tensor_pair.first;
DenseTensor& nscores = tensor_pair.second;
memory::Copy(place,
rpn_rois_data + num_proposals * 4,
paddle::memory::Copy(place,
rpn_rois->data<T>() + num_proposals * 4,
place,
proposals.data<T>(),
sizeof(T) * proposals.numel());
memory::Copy(place,
rpn_roi_probs_data + num_proposals,
paddle::memory::Copy(place,
rpn_roi_probs->data<T>() + num_proposals,
place,
scores.data<T>(),
nscores.data<T>(),
sizeof(T) * scores.numel());
if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
......@@ -385,29 +390,26 @@ class XPUGenerateProposalsV2Kernel : public framework::OpKernel<T> {
offset.emplace_back(num_proposals);
tmp_num.push_back(proposals.dims()[0]);
}
if (context.HasOutput("RpnRoisNum")) {
auto *rpn_rois_num = context.Output<Tensor>("RpnRoisNum");
rpn_rois_num->mutable_data<int>({num}, context.GetPlace());
int *num_data = rpn_rois_num->data<int>();
memory::Copy(place, num_data, cpu_place, &tmp_num[0], sizeof(int) * num);
rpn_rois_num->Resize({num});
if (rpn_rois_num != nullptr) {
rpn_rois_num->Resize(phi::make_ddim({num}));
dev_ctx.template Alloc<int>(rpn_rois_num);
int* num_data = rpn_rois_num->data<int>();
paddle::memory::Copy(
place, num_data, cpu_place, &tmp_num[0], sizeof(int) * num);
}
framework::LoD lod;
phi::LoD lod;
lod.emplace_back(offset);
rpn_rois->set_lod(lod);
rpn_roi_probs->set_lod(lod);
rpn_rois->Resize({num_proposals, 4});
rpn_roi_probs->Resize({num_proposals, 1});
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
generate_proposals_v2,
ops::XPUGenerateProposalsV2Kernel<paddle::platform::XPUDeviceContext,
float>);
rpn_rois->Resize(phi::make_ddim({num_proposals, 4}));
rpn_roi_probs->Resize(phi::make_ddim({num_proposals, 1}));
}
} // namespace phi
#endif // PADDLE_WITH_XPU
PD_REGISTER_KERNEL(generate_proposals_v2,
XPU,
ALL_LAYOUT,
phi::GenerateProposalsV2Kernel,
float) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册